In [None]:
import os
import pandas as pd
from tqdm import tqdm
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from datetime import datetime

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW

from sklearn.preprocessing import LabelEncoder
from sklearn.decomposition import PCA

In [None]:
# Timestamped identifier for this run (month-day_hour-minute-second); used to namespace output dirs.
run_name = "NoteEmbedding_" + datetime.now().strftime('%m%d_%H%M%S')

In [None]:
# Select a CJK-capable sans-serif font (presumably so Chinese class names render in legends),
# and draw minus signs as ASCII '-' rather than U+2212 (assumed: the CJK font may lack that glyph).
matplotlib.rcParams.update({
    'font.sans-serif': ['Noto Sans CJK JP'],
    'axes.unicode_minus': False,
})

In [None]:
class NoteDataset(Dataset):
    """Map-style dataset that tokenizes one note per item for sequence classification.

    Args:
        notes: indexable sequence of note strings.
        labels: indexable sequence of integer class ids, aligned with ``notes``.
        tokenizer: callable tokenizer (e.g. a Hugging Face tokenizer) accepting the
            keyword arguments used in ``__getitem__`` and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_len: every note is truncated/padded to exactly this many tokens.

    Raises:
        ValueError: if ``notes`` and ``labels`` differ in length.
    """

    def __init__(self, notes, labels, tokenizer, max_len=64):
        if len(notes) != len(labels):
            raise ValueError(
                f"notes and labels must have the same length, got {len(notes)} and {len(labels)}"
            )
        self.notes = notes
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    @property
    def lebels(self):
        # Backward-compatible alias for the original (misspelled) attribute name.
        return self.labels

    def __len__(self):
        return len(self.notes)

    def __getitem__(self, idx):
        """Return ``input_ids``/``attention_mask`` of shape (max_len,) and a scalar long ``labels``."""
        note = self.notes[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            note,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch dim of size 1. Drop only that dim so that
        # max_len == 1 still yields a 1-D tensor (a bare .squeeze() would collapse it to 0-D).
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [None]:
# Load the note -> class mapping and integer-encode the class names.
df = pd.read_csv("data/note_class_map.csv")
# Fail fast on incomplete rows: a missing note would otherwise only fail much later inside
# the tokenizer, and a missing class is not a meaningful training target.
missing_rows = df[['note', 'class']].isna().any(axis=1)
if missing_rows.any():
    raise ValueError(
        f"{int(missing_rows.sum())} row(s) in data/note_class_map.csv have a missing 'note' or 'class'"
    )
notes = df['note'].tolist()
le = LabelEncoder()
labels = le.fit_transform(df['class'])
print(f"Classes: {le.classes_}")

In [None]:
# Seed torch (all devices) and numpy so that classification-head initialisation, DataLoader
# shuffling and any randomized PCA solver are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
# Pretrained encoder plus a freshly initialised classification head sized to the label set.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese", num_labels=len(le.classes_)
).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)

In [None]:
# Wrap the encoded notes for shuffled mini-batch training (16 notes per step, 64 tokens per note).
dataset = NoteDataset(notes, labels, tokenizer, max_len=64)
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

In [None]:
def train_one_epoch():
    """Run one optimisation pass over the notebook-level ``dataloader``.

    Uses the globals ``model``, ``optimizer``, ``dataloader`` and ``device``.

    Returns:
        float: mean per-sample training loss over the whole epoch. Each batch's loss
        is weighted by its size, so a short final batch does not skew the result.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Named batch_labels to avoid shadowing the notebook-level `labels` array.
        batch_labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=batch_labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # outputs.loss is a per-batch mean; re-weight by batch size and accumulate a
        # Python float (via .item()) so no autograd graph is kept alive across steps.
        batch_size = input_ids.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("dataloader yielded no batches; cannot compute epoch loss")
    return total_loss / n_samples

In [None]:
def evaluate(epoch, batch_size=64, max_len=64):
    """Project every note's [CLS] embedding to 2-D with PCA and save a class-coloured scatter.

    Uses the notebook-level ``model``, ``tokenizer``, ``notes``, ``labels``, ``le``,
    ``device`` and ``run_name``. The figure is written to ``PCA/<run_name>/<epoch>.png``.

    Args:
        epoch: epoch number, used in the plot title and output filename.
        batch_size: number of notes encoded per forward pass.
        max_len: tokenizer padding/truncation length; keep it equal to the
            ``max_len`` used for ``NoteDataset`` so embeddings match training inputs.
    """
    embeddings = _compute_cls_embeddings(batch_size, max_len)
    embeddings_2d = PCA(n_components=2).fit_transform(embeddings)

    os.makedirs(f"PCA/{run_name}", exist_ok=True)
    _plot_pca(embeddings_2d, epoch, f"PCA/{run_name}/{epoch}.png")


def _compute_cls_embeddings(batch_size, max_len):
    """Return a (len(notes), hidden_size) array of final-layer [CLS] vectors, in ``notes`` order."""
    if len(notes) == 0:
        raise ValueError("no notes to embed")
    model.eval()
    chunks = []
    with torch.no_grad():
        for start in tqdm(range(0, len(notes), batch_size), desc="Evaluating"):
            encoding = tokenizer(
                notes[start:start + batch_size],
                truncation=True,
                padding='max_length',
                max_length=max_len,
                return_tensors='pt'
            )
            outputs = model.bert(
                input_ids=encoding['input_ids'].to(device),
                attention_mask=encoding['attention_mask'].to(device),
            )
            # Token position 0 is [CLS]; keep the batch dimension so chunks concatenate cleanly.
            chunks.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())
    return np.concatenate(chunks, axis=0)


def _plot_pca(embeddings_2d, epoch, out_path):
    """Scatter the 2-D projection with one explicitly labelled series per class and save it."""
    label_ids = np.asarray(labels)
    n_classes = len(le.classes_)
    # tab10 has only 10 distinct colours; use tab20 for larger label sets (colours repeat beyond 20).
    palette = plt.get_cmap('tab10' if n_classes <= 10 else 'tab20')

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # One scatter per class so every legend entry maps to the right class name,
        # regardless of how many classes there are.
        for class_id, class_name in enumerate(le.classes_):
            mask = label_ids == class_id
            if not mask.any():
                continue
            ax.scatter(
                embeddings_2d[mask, 0],
                embeddings_2d[mask, 1],
                color=palette(class_id % palette.N),
                alpha=0.7,
                label=str(class_name),
            )
        ax.set_title(f"Epoch {epoch} - PCA")
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
        ax.legend(title='Class', bbox_to_anchor=(1.05, 1), loc='upper left')
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Always release the figure: this runs once per epoch and figures otherwise accumulate.
        plt.close(fig)

In [None]:
def save_model(epoch):
    """Checkpoint the global ``model``'s state_dict to ``models/<run_name>/bert_epoch_<epoch>.pth``.

    The run directory is created on first use.
    """
    checkpoint_dir = f"models/{run_name}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = f"{checkpoint_dir}/bert_epoch_{epoch}.pth"
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Main loop: for each epoch, train, visualise the embedding space, then checkpoint.
epochs = 10
for ep in range(1, epochs + 1):
    print(f"====== Epoch {ep} ======")

    loss = train_one_epoch()  # one full pass over `dataloader`
    print(f"Loss: {loss}")

    evaluate(ep)    # PCA scatter of [CLS] embeddings -> PCA/<run_name>/<ep>.png
    save_model(ep)  # weights -> models/<run_name>/bert_epoch_<ep>.pth