In [23]:
import torch
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from tqdm import tqdm
from torch.nn import functional as F
import matplotlib.pyplot as plt

from torch.optim.lr_scheduler import LambdaLR

### Load Data

In [24]:
df = pd.read_csv("./Students_Grading_Dataset.csv")
# df

# for att in df:
#     print(f"{att}: {df[att][0]}")

## Manually select Cols (attribute)

In [25]:

unimportant_attribute = ['Student_ID', 'First_Name', 'Last_Name', 'Email']

filtered_df = df.drop(unimportant_attribute, axis=1)
filtered_df

Unnamed: 0,Gender,Age,Department,Attendance (%),Midterm_Score,Final_Score,Assignments_Avg,Quizzes_Avg,Participation_Score,Projects_Score,Total_Score,Grade,Study_Hours_per_Week,Extracurricular_Activities,Internet_Access_at_Home,Parent_Education_Level,Family_Income_Level,Stress_Level (1-10),Sleep_Hours_per_Night
0,Female,22,Engineering,52.29,55.03,57.82,84.22,74.06,3.99,85.90,56.09,F,6.2,No,Yes,High School,Medium,5,4.7
1,Male,18,Engineering,97.27,97.23,45.80,,94.24,8.32,55.65,50.64,A,19.0,No,Yes,,Medium,4,9.0
2,Male,24,Business,57.19,67.05,93.68,67.70,85.70,5.05,73.79,70.30,D,20.7,No,Yes,Master's,Low,6,6.2
3,Female,24,Mathematics,95.15,47.79,80.63,66.06,93.51,6.54,92.12,61.63,A,24.8,Yes,Yes,High School,High,3,6.7
4,Female,23,CS,54.18,46.59,78.89,96.85,83.70,5.97,68.42,66.13,F,15.4,Yes,Yes,High School,High,2,7.1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4995,Male,19,Business,,82.15,60.33,80.09,99.32,5.00,58.42,85.21,D,25.5,No,Yes,High School,Low,10,8.3
4996,Male,19,Business,65.11,86.31,49.80,,88.08,2.79,60.87,95.96,C,5.0,No,Yes,,Medium,4,4.0
4997,Female,24,CS,87.54,63.55,64.21,94.28,50.19,3.13,82.65,54.25,A,24.8,Yes,No,High School,Medium,4,6.3
4998,Male,23,CS,92.56,79.79,94.28,81.20,61.18,0.40,94.29,55.84,A,16.1,Yes,Yes,Bachelor's,Low,1,8.4


In [26]:
category_vars = ['Gender', 'Department', 'Grade', 'Extracurricular_Activities', 'Internet_Access_at_Home', 'Parent_Education_Level', 'Family_Income_Level']
numerical_score_vars = ['Attendance (%)', 'Midterm_Score', 'Final_Score', 'Assignments_Avg', 'Quizzes_Avg', 'Participation_Score', 'Projects_Score', 'Total_Score', 'Stress_Level (1-10)']
numerical_scalar_vars = list(set(filtered_df.columns) - set(category_vars) - set(numerical_score_vars))

## Separate rows with-Nan and without-Nan

In [27]:
nan_rows = filtered_df.isna().any(axis=1)

# Nan rows
df_nan = filtered_df[nan_rows]
print(f"row with Nan: {df_nan.shape}")
# Complete rows
df_complete = filtered_df[~nan_rows]
print(f"row without Nan: {df_complete.shape}")

row with Nan: (2419, 19)
row without Nan: (2581, 19)


In [28]:
# split df_complete into train/valid
# data_amount = int(len(df_complete) * 0.8)
# df_train = df_complete.iloc[:data_amount, :]
# df_valid = df_complete.iloc[data_amount:, :]

df_train, df_valid, _, _ = train_test_split(df_complete, df_complete, test_size=0.3, random_state=0)

print(f"df_train: {df_train.shape}")
print(f"df_valid: {df_valid.shape}")

df_train: (1806, 19)
df_valid: (775, 19)


## Preprocessing: 
1. category to numerical
2. max-min norm

In [29]:
def category_to_numerical(data):
    le = LabelEncoder()
    le.fit(data)
    num_data = le.transform(data)
    
    return num_data, le

# def max_min_norm(data, train_params = None, process_type = 'train'):
    
#     if process_type == 'train':
#         data_max = np.max(data)
#         data_min = np.min(data)
#     else:
#         data_max = train_params['Age'][0]
#         data_min = train_params['Age'][1]
        
#     norm_data = (data - data_min) / (data_max - data_min + 1e-3)    
    
#     if process_type == 'train':
#         return norm_data, data_max, data_min
#     else:
#         return norm_data
    

def max_min_norm_score(data, train_params = None, process_type = 'train'):
    
    if process_type == 'train':
        data_max = 100
        data_min = 0
    else:
        data_max = 100
        data_min = 0
        
    norm_data = (data - data_min) / (data_max - data_min)    
    
    if process_type == 'train':
        return norm_data, data_max, data_min
    else:
        return norm_data
    
def max_min_norm_scalar(data, train_params = None, process_type = 'train'):
    
    if process_type == 'train':
        data_max = 10
        data_min = 0
    else:
        data_max = 10
        data_min = 0
        
    norm_data = (data - data_min) / (data_max - data_min)    
    
    if process_type == 'train':
        return norm_data, data_max, data_min
    else:
        return norm_data

    
def preprocessing(df, train_params = None, process_type = 'train'):
    
    new_df = pd.DataFrame()
    
    if process_type == 'train':
        train_params = {}
        category_var_len = {}

    # Category 
    for cat_name in category_vars:
        cat_var = df[cat_name]
        if process_type == 'train':
            cat_var, le = category_to_numerical(cat_var)
            train_params[f'{cat_name}_le'] = le
            category_var_len[f'{cat_name}'] = len(np.unique(cat_var))
        else:
            cat_var = train_params[f'{cat_name}_le'].transform(cat_var)
        new_df[f'{cat_name}'] = cat_var
    
    # Numerical score
    for num_name in numerical_score_vars:
        num_var = df[num_name]
        if process_type == 'train':
            num_var, data_max, data_min = max_min_norm_score(num_var, process_type = 'train')
            train_params[num_name] = [data_max, data_min]
        else:
            num_var = max_min_norm_score(num_var, train_params, process_type = 'valid')
        new_df[num_name] = num_var.values
    
    # Numerical scalar
    for num_name in numerical_scalar_vars:
        num_var = df[num_name]
        num_var = np.log(num_var)
        if process_type == 'train':
            num_var, data_max, data_min = max_min_norm_scalar(num_var, process_type = 'train')
            train_params[num_name] = [data_max, data_min]
        else:
            num_var = max_min_norm_scalar(num_var, train_params, process_type = 'valid')
        new_df[num_name] = num_var.values
        
        
    if process_type == 'train':
        return new_df, train_params, category_var_len
    else:
        return new_df


In [30]:
processed_df_train, train_params, category_var_len = preprocessing(df_train, process_type = 'train')
# train_params
print(f"category_var_len: {category_var_len}")
print(f"processed_df_train: {processed_df_train.shape}")
processed_df_train.head()

category_var_len: {'Gender': 2, 'Department': 4, 'Grade': 5, 'Extracurricular_Activities': 2, 'Internet_Access_at_Home': 2, 'Parent_Education_Level': 4, 'Family_Income_Level': 3}
processed_df_train: (1806, 19)


Unnamed: 0,Gender,Department,Grade,Extracurricular_Activities,Internet_Access_at_Home,Parent_Education_Level,Family_Income_Level,Attendance (%),Midterm_Score,Final_Score,Assignments_Avg,Quizzes_Avg,Participation_Score,Projects_Score,Total_Score,Stress_Level (1-10),Sleep_Hours_per_Night,Study_Hours_per_Week,Age
0,1,1,1,1,1,0,1,0.8989,0.4255,0.5045,0.5528,0.7634,0.0231,0.5783,0.7426,0.04,0.212823,0.332504,0.317805
1,1,3,2,0,1,3,2,0.5562,0.819,0.4543,0.538,0.9215,0.095,0.6371,0.8995,0.02,0.193152,0.177495,0.313549
2,1,1,4,0,1,1,1,0.6309,0.6057,0.568,0.5342,0.8288,0.0529,0.5671,0.5943,0.07,0.158924,0.293916,0.294444
3,1,2,2,1,1,1,2,0.6077,0.9869,0.455,0.5771,0.5339,0.0948,0.5516,0.8337,0.1,0.210413,0.282138,0.294444
4,0,2,1,0,1,1,2,0.959,0.8459,0.6858,0.5204,0.6389,0.0156,0.6554,0.8854,0.06,0.218605,0.184055,0.304452


In [31]:
processed_df_valid = preprocessing(df_valid, train_params, process_type = 'valid')
print(f"processed_df_valid: {processed_df_valid.shape}")
processed_df_valid.head()

processed_df_valid: (775, 19)


Unnamed: 0,Gender,Department,Grade,Extracurricular_Activities,Internet_Access_at_Home,Parent_Education_Level,Family_Income_Level,Attendance (%),Midterm_Score,Final_Score,Assignments_Avg,Quizzes_Avg,Participation_Score,Projects_Score,Total_Score,Stress_Level (1-10),Sleep_Hours_per_Night,Study_Hours_per_Week,Age
0,1,3,2,0,1,2,1,0.5154,0.6404,0.5002,0.8728,0.9699,0.0399,0.8258,0.8888,0.08,0.214007,0.283908,0.294444
1,1,1,1,1,1,2,1,0.997,0.9376,0.6142,0.5457,0.7576,0.0308,0.5397,0.6468,0.08,0.141099,0.194591,0.289037
2,1,3,2,0,1,3,0,0.6151,0.7315,0.419,0.5537,0.5564,0.077,0.7862,0.5429,0.09,0.193152,0.294969,0.317805
3,1,1,3,1,1,2,1,0.5228,0.8212,0.9962,0.5518,0.5023,0.0058,0.8334,0.6564,0.07,0.18563,0.319458,0.304452
4,1,2,0,0,1,0,2,0.9474,0.534,0.6216,0.8329,0.7903,0.0408,0.625,0.6557,0.03,0.200148,0.236085,0.313549


In [32]:
class TableDataset(Dataset):
    def __init__(self, data):
        self.data = np.array(data)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, id):
        dat = self.data[id, :]
        dat = torch.from_numpy(dat)
        return dat

In [33]:
BATCH_SIZE = 256
DROPOUT = 0.2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


train_dataset = TableDataset(processed_df_train)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)

valid_dataset = TableDataset(processed_df_valid)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)

In [34]:
# category_var_len: {'Gender': 2, 'Department': 4, 'Grade': 5, 
                #    'Extracurricular_Activities': 2, 'Internet_Access_at_Home': 2, 
                #    'Parent_Education_Level': 4, 'Family_Income_Level': 3}


In [35]:
class Head(nn.Module):
    def __init__(self, num_head, transformer_emb_size, category_emb_size):
        super().__init__()
        self.num_head = num_head
        self.up_emb_size = transformer_emb_size
        self.category_emb_size = category_emb_size
        
        # up_emb_size * 3 for qkv
        self.qkv_fn = nn.Linear(self.up_emb_size, self.up_emb_size * 3, bias=False)
        self.proj_qkv = nn.Linear(self.up_emb_size, self.up_emb_size, bias = False)
                
        # other
        self.up_dropout = nn.Dropout(DROPOUT)
        self.att_dropout = nn.Dropout(DROPOUT)
        self.down_dropout = nn.Dropout(DROPOUT)
        
    def forward(self, x):
        '''
        num_vars is seq_len in here
        
        ori_hidden_size = 1
        up_emb = 32
        
        '''        
        batch_size, num_vars, _ = x.shape
        
        # qkv: up_emb_size * 3
        x = self.qkv_fn(x) 
        q, k, v = x.split(self.up_emb_size, dim = 2) 
        
        # split head: each shape = [batch_size, num_head, num_vars(seq_len), head_size = 8]
        q = q.view(batch_size, num_vars, self.num_head, self.up_emb_size // self.num_head).transpose(1, 2)
        k = k.view(batch_size, num_vars, self.num_head, self.up_emb_size // self.num_head).transpose(1, 2)
        v = v.view(batch_size, num_vars, self.num_head, self.up_emb_size // self.num_head).transpose(1, 2)
        
        # attention matrix calculation: [batch_size, num_head, num_vars(seq_len), num_vars]
        att = (q @ k.transpose(-2,-1)) * (1 / torch.sqrt(torch.ones([1]).to(device) * k.size(-1)))
        # att = self.modify_att(att)
        att = F.softmax(att, dim = -1)
        att = self.att_dropout(att)
        
        # att matrix * V: [batch_size, num_head, num_vars(seq_len), head_size]
        out = att @ v
        out = out.transpose(1,2).contiguous().view(batch_size, num_vars, self.up_emb_size)
        out = self.proj_qkv(out)
        
        return out
    

In [36]:
class MLP(nn.Module):

    def __init__(self, transformer_emb_size):
        super().__init__()
        self.c_fc    = nn.Linear(transformer_emb_size, transformer_emb_size * 3)
        
        self.gelu    = nn.GELU(approximate='tanh')
        
        self.c_proj  = nn.Linear(transformer_emb_size * 3, transformer_emb_size)

        self.dropout = nn.Dropout(DROPOUT)
        
    def forward(self, x):
        x = self.c_fc(x)

        x = self.gelu(x)
        
        x = self.c_proj(x)

        x = self.dropout(x)
        
        return x

In [37]:
class Block(nn.Module):
    def __init__(self, num_head, transformer_emb_size, category_emb_size):
        super(Block, self).__init__()
        self.ln_head = nn.LayerNorm(transformer_emb_size)
        self.head = Head(num_head, transformer_emb_size, category_emb_size)
        self.ln_mlp = nn.LayerNorm(transformer_emb_size)
        self.mlp = MLP(transformer_emb_size)
                
    def forward(self, x):
        
        x = x + self.head(self.ln_head(x))
        x = x + self.mlp(self.ln_mlp(x))

        return x

In [38]:
class GPT(nn.Module):
    def __init__(self, layer, num_head, seq_len, transformer_emb_size, category_emb_size, dropout):
        super().__init__()
        
        self.blocks = nn.ModuleList(Block(num_head, transformer_emb_size, category_emb_size) for _ in range(layer))
        self.positions = nn.Parameter(torch.rand(seq_len, transformer_emb_size))
        self.dropout = nn.Dropout(dropout)
            
    def forward(self, x):

        # add postion info
        x = x + self.positions
        x = self.dropout(x)
        
        for block in self.blocks:
            x = block(x)

        return x

In [39]:
class tableModel(nn.Module):
    def __init__(self, category_var_len):
        super(tableModel,self).__init__()
        self.num_category_var = len(category_var_len)
        self.num_numerical_var = 12
        self.category_emb_size = 32
        self.category_dict = category_var_len
        self.batch_size = BATCH_SIZE
                
        self.mask_prob = 0.05

        self.transformer_layer = 12
        self.transformer_emb_size = 768
        self.num_head = 12
        self.seq_len = self.num_category_var + self.num_numerical_var
        
        '''# encode category vars'''
        self.encode_gender = nn.Embedding(category_var_len['Gender'] + 1, self.transformer_emb_size)
        self.encode_depart = nn.Embedding(category_var_len['Department'] + 1, self.transformer_emb_size)
        self.encode_grade = nn.Embedding(category_var_len['Grade'] + 1, self.transformer_emb_size)
        self.encode_activity = nn.Embedding(category_var_len['Extracurricular_Activities'] + 1, self.transformer_emb_size)        
        self.encode_internet = nn.Embedding(category_var_len['Internet_Access_at_Home'] + 1, self.transformer_emb_size)
        self.encode_parent = nn.Embedding(category_var_len['Parent_Education_Level'] + 1, self.transformer_emb_size)        
        self.encode_income = nn.Embedding(category_var_len['Family_Income_Level'] + 1, self.transformer_emb_size)  
        
        self.encoders = [self.encode_gender, self.encode_depart, self.encode_grade, self.encode_activity, self.encode_internet,
                        self.encode_parent, self.encode_income]
        
        self.encodes_numerical = nn.ModuleList(nn.Linear(1, self.transformer_emb_size) for _ in range( self.num_numerical_var))
        
        self.encode_dropout = nn.Dropout(DROPOUT)
        
        '''# decode category vars'''
        self.decode_gender = nn.Linear(self.transformer_emb_size, category_var_len['Gender'] + 1, bias=False)
        self.encode_gender.weight = self.decode_gender.weight        
        self.decode_depart = nn.Linear(self.transformer_emb_size, category_var_len['Department'] + 1, bias=False)
        self.encode_depart.weight = self.decode_depart.weight       
        self.decode_grade = nn.Linear(self.transformer_emb_size, category_var_len['Grade'] + 1, bias=False)
        self.encode_grade.weight = self.decode_grade.weight          
        self.decode_activity = nn.Linear(self.transformer_emb_size, category_var_len['Extracurricular_Activities'] + 1, bias=False)
        self.encode_activity.weight = self.decode_activity.weight               
        self.decode_internet = nn.Linear(self.transformer_emb_size, category_var_len['Internet_Access_at_Home'] + 1, bias=False)
        self.encode_internet.weight = self.decode_internet.weight    
        self.decode_parent = nn.Linear(self.transformer_emb_size, category_var_len['Parent_Education_Level'] + 1, bias=False)
        self.encode_parent.weight = self.decode_parent.weight    
        self.decode_income = nn.Linear(self.transformer_emb_size, category_var_len['Family_Income_Level'] + 1, bias=False)
        self.encode_income.weight = self.decode_income.weight    
        
        self.decoders = [self.decode_gender, self.decode_depart, self.decode_grade, self.decode_activity, self.decode_internet,
                        self.decode_parent, self.decode_income]
        
        self.decodes_numerical = nn.ModuleList(nn.Linear(self.transformer_emb_size, 1) for _ in range( self.num_numerical_var))
        
        self.decode_dropout = nn.Dropout(DROPOUT)

        '''# transformer'''
        self.gpt = GPT(layer = self.transformer_layer, 
                       num_head = self.num_head, 
                       seq_len = self.seq_len,
                       transformer_emb_size = self.transformer_emb_size,
                       category_emb_size = self.category_emb_size,
                       dropout = DROPOUT)
        
        
        ''' linear for numerical '''
        self.linear_numerical1 = nn.Linear(self.num_numerical_var, 32, bias = False)
        self.linear_numerical2 = nn.Linear(32, 128 , bias = False)
        self.linear_numerical3 = nn.Linear(128, 32 , bias = False)
        self.linear_numerical4 = nn.Linear(32, self.num_numerical_var , bias = False)

        self.relu = nn.ReLU()

        
    def masking_table(self, x, seed=42, training = True):
        """
        x: (batch_size, num_var = 19)
        """
        # Set random seed for reproducibility
        
        if training:
            seed = torch.randint(0, 5, (1,))
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        else:
            torch.manual_seed(42)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(42)

        self.masking_prob = self.mask_prob 
        device = x.device  # Get the device from input tensor

        # Category masking
        category_var = x[:, :self.num_category_var].long()  # Ensure category_var is integer type
        random_cat = torch.rand_like(category_var, dtype=torch.float, device=device)
        masking_cat = random_cat < self.masking_prob
        mask_token = torch.tensor([2, 4, 5, 2, 2, 4, 3], device=device, dtype=torch.long).expand_as(category_var)

        # Apply mask in-place (avoiding memory allocation overhead)
        masked_category_var = category_var.clone()  # Clone to avoid modifying input
        masked_category_var[masking_cat] = mask_token[masking_cat]

        # Numerical masking
        numerical_var = x[:, -self.num_numerical_var:].float()  # Ensure numerical_var is float type
        random_numerical = torch.rand_like(numerical_var, dtype=torch.float, device=device)
        masking_numerical = random_numerical < self.masking_prob

        masked_numerical_var = numerical_var.clone()  # Clone to avoid modifying input
        masked_numerical_var[masking_numerical] = 0 # Directly set masked values to zero

        masking_position = {
            'masking_category': masking_cat,
            'masking_numerical': masking_numerical,
        }


        # Concatenating the masked category and numerical variables
        return torch.cat([masked_category_var, masked_numerical_var], dim=1), masking_position
        # return torch.cat([category_var, masked_numerical_var], dim=1), masking_position

                        
    def forward(self, x, training):
        ''' masking'''
        x, masking_position = self.masking_table(x, training)

        ''' Encoding ''' 
        # category vars
        cat_vars = []
        for c_id, encode_fn in zip(range(self.num_category_var), self.encoders):
            emb_c = encode_fn(x[:,c_id].long())
            emb_c = self.encode_dropout(emb_c)
            cat_vars.append(emb_c)
        cat_vars = torch.stack(cat_vars, dim = 1).float() 
                
        # numerical vars
        num_vars = x[:, - self.num_numerical_var:].float()
        num_emb = []
        for n_id, encode_fn in zip(range(self.num_numerical_var), self.encodes_numerical):
            num_var = num_vars[:, n_id].view(-1, 1)
            emb_n = encode_fn(num_var)
            emb_c = self.encode_dropout(emb_c)
            num_emb.append(emb_n)
        num_emb = torch.stack(num_emb, dim = 1).float() 

        # combine category and numerical vars        
        x = torch.cat([cat_vars, num_emb], dim = 1)

        '''
        Transformer
        '''
        x = self.gpt(x)
        
        ''' Decode category ''' 
        # split numerical and category
        num_vars = x[:, - self.num_numerical_var:]
        cat_vars = x[:, :self.num_category_var]
            
        # category vars
        decoded_cat_vars = []
        for c_id, decode_fn in zip(range(self.num_category_var), self.decoders):
            emb_c = cat_vars[:, c_id, :]
            c_var = decode_fn(emb_c)
            c_var = torch.softmax(c_var, dim = -1)
            c_var = self.decode_dropout(c_var)
            decoded_cat_vars.append(c_var)
        
        # numerical vars
        decoded_num_vars = []
        for n_id, decode_fn in zip(range(self.num_numerical_var), self.decodes_numerical):
            decode_n = decode_fn(num_vars[:, n_id])
            decode_n = self.decode_dropout(decode_n)
            decoded_num_vars.append(decode_n)
        decoded_num_vars = torch.cat(decoded_num_vars, dim = 1)

        decoded_num_vars_ori = decoded_num_vars
        decoded_num_vars = self.relu(self.linear_numerical1(decoded_num_vars))
        decoded_num_vars = self.relu(self.linear_numerical2(decoded_num_vars))
        decoded_num_vars = self.decode_dropout(decoded_num_vars)
        decoded_num_vars = self.relu(self.linear_numerical3(decoded_num_vars))
        decoded_num_vars = self.decode_dropout(decoded_num_vars)
        decoded_num_vars = self.linear_numerical4(decoded_num_vars)
        decoded_num_vars = decoded_num_vars + decoded_num_vars_ori

        return decoded_num_vars, decoded_cat_vars, masking_position

In [40]:

train_dataset = TableDataset(processed_df_train)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)

valid_dataset = TableDataset(processed_df_valid)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)

In [41]:
def linear_warmup_decay_lr(lr_init, lr_final, num_warmup_steps, num_training_steps):
    """
    Returns a lambda function for LambdaLR.
    - lr_init: 初始學習率
    - lr_final: 最終學習率（不是 0）
    - num_warmup_steps: 預熱步數
    - num_training_steps: 總訓練步數
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return current_step / num_warmup_steps  # 線性預熱
        else:
            progress = (current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps)
            return (1 - progress) * (1 - lr_final / lr_init) + (lr_final / lr_init)  # 線性衰減到 lr_final
    return lr_lambda

In [42]:
LEARNING_RATE = 2e-3
EPOCHS = 5000


model = tableModel(category_var_len).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95), eps=1e-6, weight_decay=5e-1)

scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_decay_lr(lr_init = LEARNING_RATE, lr_final = LEARNING_RATE * 1e-2, num_warmup_steps = 100, num_training_steps = EPOCHS))


In [43]:
MSE_loss_fn = nn.MSELoss(reduction='none')  # 逐元素 MSE
CE_loss_fn = nn.CrossEntropyLoss(reduction='none')  # 逐元素 CrossEntropy

def loss_fn(pred_numerical, pred_category, label, mask_position):
    device = label.device
    
    num_numerical = 12
    num_category = 7
    ratio_numerical = num_numerical / (num_numerical + num_category)
    ratio_category = 1 / (num_numerical + num_category)

    masking_category = mask_position['masking_category']  # shape: (batch_size, num_category)
    masking_numerical = mask_position['masking_numerical']  # shape: (batch_size, num_numerical)

    label_category = label[:, :num_category]
    label_numerical = label[:, -num_numerical:]

    total_loss = torch.zeros(1, device=device)

    # === 1. MSE Loss ===
    # 先用 masking_numerical 過濾 pred_numerical 和 label_numerical
    pred_numerical_masked = pred_numerical[masking_numerical]
    label_numerical_masked = label_numerical[masking_numerical]

    if pred_numerical_masked.numel() > 0:  # 確保有 mask 位置
        mse_loss = MSE_loss_fn(pred_numerical_masked, label_numerical_masked).mean()
        total_loss += (mse_loss * 3) * ratio_numerical

    # === 2. CrossEntropy Loss ===
    for i in range(num_category):
        category_mask = masking_category[:, i]  # shape: (batch_size,)
        pred = pred_category[i]  # shape: (batch_size, num_classes)
        label_cat = label_category[:, i].long()  # shape: (batch_size,)

        # 先用 mask 過濾 pred 和 label
        pred_masked = pred[category_mask]
        label_masked = label_cat[category_mask]
        
        if pred_masked.shape[0] > 0:  # 確保有 mask 位置
            # print(f"pred_masked: {pred_masked.shape}")
            # print(f"label_masked: {label_masked.shape}")
            # raise
            ce_loss = CE_loss_fn(pred_masked, label_masked).mean()
            total_loss += ce_loss * ratio_category

    ce_loss = (total_loss - mse_loss)

    return total_loss, mse_loss, ((total_loss - mse_loss))


In [44]:

train_LOSS = []
valid_LOSS = []

for epoch in tqdm(range(EPOCHS), desc="iterate epoch"):
    losses = []
    mse_losses = []
    ce_losses = []
    
    val_losses = []
    val_mse_losses = []
    val_ce_losses = []
    
    
    model.train()
    for data in train_dataloader:
        data = data.float().to(device)

        pred_numerical, pred_category, masking_position = model(data, training = True)
        loss, mse_loss, ce_loss = loss_fn(pred_numerical, pred_category, data, masking_position)
        losses.append(loss.item())
        mse_losses.append(mse_loss.item())
        ce_losses.append(ce_loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    scheduler.step()
    
    losses = np.mean(losses)
    mse_losses = np.mean(mse_losses)
    ce_losses = np.mean(ce_losses)

    train_LOSS.append(losses)
    
    if epoch % 100 == 0:    
        print(f"epoch: {epoch}, loss: {losses}, mse: {mse_losses}, ce: {ce_losses}")
    
    with torch.no_grad():
        model.eval()
        for data in valid_dataloader:
            data = data.float().to(device)
                
            pred_numerical, pred_category, masking_position = model(data, training = False)
            loss, mse_loss, ce_loss = loss_fn(pred_numerical, pred_category, data, masking_position)
            val_losses.append(loss.item())
            val_mse_losses.append(mse_loss.item())
            val_ce_losses.append(ce_loss.item())
            
    val_losses = np.mean(val_losses)
    val_mse_losses = np.mean(val_mse_losses)
    val_ce_losses = np.mean(val_ce_losses)
    
    valid_LOSS.append(val_losses)
    
    if epoch % 100 == 0:
        print(f"epoch: {epoch}, val_loss: {val_losses}, val_mse: {val_mse_losses}, val_ce_losses: {val_ce_losses}")
        print()
        
        
        
        # pred = model.inference(data)
        
        
    

iterate epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

epoch: 0, loss: 1.674682930111885, mse: 0.6244314983487129, ce: 1.0502514243125916


iterate epoch:   0%|          | 1/5000 [00:00<1:20:34,  1.03it/s]

epoch: 0, val_loss: 1.2914193272590637, val_mse: 0.444048672914505, val_ce_losses: 0.8473706543445587



iterate epoch:   2%|▏         | 100/5000 [01:45<1:26:36,  1.06s/it]

epoch: 100, loss: 2.3213697522878647, mse: 1.0031017623841763, ce: 1.3182679936289787


iterate epoch:   2%|▏         | 101/5000 [01:46<1:28:07,  1.08s/it]

epoch: 100, val_loss: 2.8029305934906006, val_mse: 1.2711272835731506, val_ce_losses: 1.53180330991745



iterate epoch:   4%|▍         | 200/5000 [03:32<1:25:28,  1.07s/it]

epoch: 200, loss: 0.4757934585213661, mse: 0.055387381464242935, ce: 0.4204060770571232


iterate epoch:   4%|▍         | 201/5000 [03:34<1:25:33,  1.07s/it]

epoch: 200, val_loss: 0.7004018127918243, val_mse: 0.15782703645527363, val_ce_losses: 0.5425747707486153



iterate epoch:   6%|▌         | 300/5000 [05:19<1:22:10,  1.05s/it]

epoch: 300, loss: 0.34298280254006386, mse: 0.009488878247793764, ce: 0.33349392376840115


iterate epoch:   6%|▌         | 301/5000 [05:20<1:19:35,  1.02s/it]

epoch: 300, val_loss: 0.4800809174776077, val_mse: 0.04225102625787258, val_ce_losses: 0.43782988935709



iterate epoch:   8%|▊         | 400/5000 [07:05<1:24:18,  1.10s/it]

epoch: 400, loss: 0.3320746775716543, mse: 0.010361134947743267, ce: 0.32171354070305824


iterate epoch:   8%|▊         | 401/5000 [07:06<1:24:45,  1.11s/it]

epoch: 400, val_loss: 0.4493557848036289, val_mse: 0.025397315621376038, val_ce_losses: 0.4239584803581238



iterate epoch:  10%|█         | 500/5000 [08:52<1:13:50,  1.02it/s]

epoch: 500, loss: 0.3148099761456251, mse: 0.007136636122595519, ce: 0.3076733388006687


iterate epoch:  10%|█         | 501/5000 [08:53<1:14:23,  1.01it/s]

epoch: 500, val_loss: 0.5045193880796432, val_mse: 0.04884209576994181, val_ce_losses: 0.4556773006916046



iterate epoch:  12%|█▏        | 600/5000 [10:40<1:18:21,  1.07s/it]

epoch: 600, loss: 0.3056306689977646, mse: 0.004770202998770401, ce: 0.3008604682981968


iterate epoch:  12%|█▏        | 601/5000 [10:41<1:17:55,  1.06s/it]

epoch: 600, val_loss: 0.4848494827747345, val_mse: 0.03638424398377538, val_ce_losses: 0.4484652280807495



iterate epoch:  14%|█▍        | 700/5000 [12:21<1:16:56,  1.07s/it]

epoch: 700, loss: 0.3116259425878525, mse: 0.008747968066018075, ce: 0.30287797562777996


iterate epoch:  14%|█▍        | 701/5000 [12:23<1:17:41,  1.08s/it]

epoch: 700, val_loss: 0.481229692697525, val_mse: 0.03567779203876853, val_ce_losses: 0.44555190950632095



iterate epoch:  16%|█▌        | 800/5000 [14:07<1:13:55,  1.06s/it]

epoch: 800, loss: 0.3039394151419401, mse: 0.005986248317640275, ce: 0.29795316606760025


iterate epoch:  16%|█▌        | 801/5000 [14:08<1:15:09,  1.07s/it]

epoch: 800, val_loss: 0.4702528491616249, val_mse: 0.036556096747517586, val_ce_losses: 0.4336967468261719



iterate epoch:  18%|█▊        | 900/5000 [15:52<1:11:04,  1.04s/it]

epoch: 900, loss: 0.3033582028001547, mse: 0.00372694552061148, ce: 0.2996312566101551


iterate epoch:  18%|█▊        | 901/5000 [15:53<1:10:33,  1.03s/it]

epoch: 900, val_loss: 0.4818481430411339, val_mse: 0.030217438470572233, val_ce_losses: 0.45163069665431976



iterate epoch:  20%|██        | 1000/5000 [17:38<1:10:43,  1.06s/it]

epoch: 1000, loss: 0.3068159241229296, mse: 0.008642581000458449, ce: 0.29817334190011024


iterate epoch:  20%|██        | 1001/5000 [17:39<1:09:45,  1.05s/it]

epoch: 1000, val_loss: 0.4574701189994812, val_mse: 0.03220792347565293, val_ce_losses: 0.425262201577425



iterate epoch:  22%|██▏       | 1100/5000 [19:23<1:06:39,  1.03s/it]

epoch: 1100, loss: 0.3028158098459244, mse: 0.005120024703501258, ce: 0.29769578762352467


iterate epoch:  22%|██▏       | 1101/5000 [19:24<1:06:12,  1.02s/it]

epoch: 1100, val_loss: 0.5390864461660385, val_mse: 0.05905591230839491, val_ce_losses: 0.48003053665161133



iterate epoch:  24%|██▍       | 1200/5000 [21:08<1:02:55,  1.01it/s]

epoch: 1200, loss: 0.29885804653167725, mse: 0.0051946941239293665, ce: 0.29366335272789


iterate epoch:  24%|██▍       | 1201/5000 [21:09<1:04:17,  1.02s/it]

epoch: 1200, val_loss: 0.4484665170311928, val_mse: 0.02353182714432478, val_ce_losses: 0.424934696406126



iterate epoch:  26%|██▌       | 1300/5000 [22:57<1:08:37,  1.11s/it]

epoch: 1300, loss: 0.299773383885622, mse: 0.006500284740468487, ce: 0.29327309876680374


iterate epoch:  26%|██▌       | 1301/5000 [22:58<1:08:18,  1.11s/it]

epoch: 1300, val_loss: 0.45818449556827545, val_mse: 0.03422137862071395, val_ce_losses: 0.42396312206983566



iterate epoch:  28%|██▊       | 1397/5000 [24:41<1:03:41,  1.06s/it]


KeyboardInterrupt: 

In [None]:


plt.plot(range(len(train_LOSS)), train_LOSS, color = 'blue')
plt.plot(range(len(valid_LOSS)), valid_LOSS, color = 'red')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
from PIL import Image

img = Image.open('../../dataset/train/0.png')

In [None]:
img_np = np.array(img)

In [None]:
img_np.shape

In [None]:
a = torch.rand((128,128))
a.shape

In [None]:
128 * 128

In [None]:
mean = torch.mean(a)
std = torch.std(a)

outlier_upper = mean + 1 * std
outlier_down = mean - 1 * std

(a < outlier_upper) & (a > outlier_down)

In [None]:
k = a[(a < outlier_upper) & (a > outlier_down)]
k

In [None]:
k = a[(a < outlier_upper) & (a > outlier_down)]
k

In [None]:
torch.sum(((a < outlier_upper) & (a > outlier_down)))

In [None]:
mask = torch.ones((8,8))
causal_mask = torch.tril(mask)
# causal_mask[:8, :8] = float('-inf')
causal_mask

In [None]:
causal_mask = torch.where(causal_mask == 0, float('-inf'), causal_mask)

In [None]:
causal_mask

In [None]:
c = torch.rand((64, 8,8))
c.shape

In [None]:
d = c * causal_mask
d[0]

In [None]:
e = torch.softmax(d, dim=-1)
e

In [None]:
0.04

In [None]:
torch.arange(2)

In [None]:
nn.Parameter(torch.rand())