In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

In [2]:
import random
import tqdm
import math
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

In [None]:
class cfg:
    embedding_size = 32
    hidden_size = 128
    vocab_size = 101
    ouput_size = vocab_size
    max_length = 15
    
    test_round = 1000
    evaluate_round = 10

    batch_size = 32
    steps = int(15e3)
    lr = 5e-4
    
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    print(f'device:{device}')

    padding_value = 0
    eos = 100

device:cpu


In [4]:
# determine if is prime
def is_prime(num):
    if num < 2:
        return False
    for i in range(2, int(num**0.5) + 1):
        if num % i == 0:
            return False
    return True

# generate prime
def generate_data(num_samples, max_length=cfg.max_length):
    input_data = []
    output_data = []
    for _ in range(num_samples):
        # random generate
        input_seq = [random.randint(0, cfg.vocab_size-2) for _ in range(random.randint(1, max_length))]
        # build output
        output_seq = [x for x in sorted([x for x in input_seq if is_prime(x)], reverse=True)]+[cfg.eos]
        #data.append((",".join(map(str, input_seq)), ",".join(output_seq)))
        input_data.append(input_seq)
        output_data.append(output_seq)
    return input_data, output_data

def generate_data(num_samples, reverse, max_length=cfg.max_length):
    input_data = []
    output_data = []
    for _ in range(num_samples):
        # random generate
        input_seq = [random.randint(1, 100) for _ in range(random.randint(1, max_length))]
        # build target
        output_seq = [x for x in input_seq if is_prime(x)] #[::-1]
        output_seq.append(cfg.eos)
        if reverse:
            input_seq = input_seq[::-1]
        else:
            pass
        input_data.append(input_seq)
        output_data.append(output_seq)
    return input_data, output_data


# padding tensor to the max_length
def list_2_tensor(data):
    tensor_list = [torch.tensor(sublist, dtype=torch.long, device=cfg.device) for sublist in data]
    padded_tensor = pad_sequence(tensor_list, batch_first=True, padding_value=cfg.padding_value)

    return padded_tensor

class Seq_1(nn.Module):
    '''
    input are batchs of seqs, seqs have the undetermined length
    output is the final hidden_state
    '''
    def __init__(self):
        super(Seq_1, self).__init__()
        #self.xh = nn.Linear(cfg.embedding_size, cfg.hidden_size)
        self.xh = nn.Sequential(
            nn.Linear(cfg.embedding_size, cfg.hidden_size*2),
            nn.ReLU(),
            nn.Linear(cfg.hidden_size*2, cfg.hidden_size//2),
            nn.ReLU(),
            nn.Linear(cfg.hidden_size//2, cfg.hidden_size)
        )
        #self.hh = nn.Linear(cfg.hidden_size, cfg.hidden_size)
        self.hh = nn.Sequential(
            nn.Linear(cfg.hidden_size, cfg.hidden_size*2),
            nn.ReLU(),
            nn.Linear(cfg.hidden_size*2, cfg.hidden_size//2),
            nn.ReLU(),
            nn.Linear(cfg.hidden_size//2, cfg.hidden_size)
        )
        self.sigmoid = nn.Sigmoid() # Sigmoid is a class, inherit from nn.Module
        self.tanh = nn.Tanh()

    def forward(self, seq, input_lengths):
        batch_size, seq_len, embedding_size = seq.size()
        mask = torch.arange(seq_len, device=cfg.device).expand(batch_size, -1) < input_lengths.unsqueeze(1) #shape become [batch_size, 1] '1' can be used in any compare, no matter the shape of another tensor (auto broadcast)

        # hidden state should correspond to the batch_size, is not the parameter, is more like the temperate variable
        hidden_state = torch.zeros(batch_size, cfg.hidden_size, device=cfg.device) # needn't expand into two dim, we can just simply broadcast it
        #hidden_state = torch.zeros(cfg.hidden_size, device=cfg.device)
        #tokens = torch.unbind(seq, dim=1) # unbind the seq form the dim 1

        for t in range(seq_len):
            token = seq[:,t,:] # unbind from the dim of tokens
            current_mask = mask[:, t].unsqueeze(1)
            #temp_hidden_state = self.sigmoid(self.xh(token)+self.hh(hidden_state))
            temp_hidden_state = self.tanh(self.xh(token)+self.hh(hidden_state))
            hidden_state = torch.where(current_mask, temp_hidden_state, hidden_state)
    
        return hidden_state

class Seq_2(nn.Module):
    '''
    input: hidden state from seq_1
    output result tokens
    '''
    def __init__(self):
        super(Seq_2, self).__init__()
        self.hh = nn.Sequential(
            nn.Linear(cfg.hidden_size, cfg.hidden_size*2),
            nn.ReLU(),
            nn.Linear(cfg.hidden_size*2, cfg.hidden_size//2),
            nn.ReLU(),
            nn.Linear(cfg.hidden_size//2, cfg.hidden_size),
            nn.ReLU()
        )
        self.hv = nn.Sequential(
            nn.Linear(cfg.hidden_size, cfg.hidden_size*2),
            nn.ReLU(),
            nn.Linear(cfg.hidden_size*2, cfg.hidden_size//2),
            nn.ReLU(),
            nn.Linear(cfg.hidden_size//2, cfg.vocab_size)
        )
        self.sigmoid = nn.Sigmoid() # Sigmoid is a class, inherit from nn.Module

    def forward(self, hidden_state, decode_length):
        batch_size, _ = hidden_state.size()
        outputs = torch.zeros(batch_size, decode_length, cfg.vocab_size, device=cfg.device) # this tensor have a continual ram space, use it to avoid using torch.cat(), which cause O(n^2) complexity
        
        for t in range(decode_length):
            hidden_state = self.hh(hidden_state)
            outputs[:,t,:] = self.hv(hidden_state) # just output logits, loss function will handle softmax
    
        return outputs

In [5]:
def lcs_length(a, b):
    m, n = len(a), len(b)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if a[i - 1] == b[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    return dp[m][n]

def calculate_metrics(target, pred, pad_token=100):
    # get the legal length
    target_valid_len = target.index(pad_token) if pad_token in target else len(target)
    target_valid = target[:target_valid_len]
    
    # get the legal length 
    if pad_token in pred:
        pred_valid_len = pred.index(pad_token)
        pred_valid = pred[:pred_valid_len]
    else:
        pred_valid_len = len(pred)
        pred_valid = pred.copy()
    
    aligned_pred = []
    for i in range(target_valid_len):
        if i < len(pred_valid):
            aligned_pred.append(pred_valid[i])
        else:
            aligned_pred.append(pad_token)  # 填充

    match_count = sum(1 for t, p in zip(target_valid, aligned_pred) if t == p)
    psa = match_count / target_valid_len if target_valid_len > 0 else 0.0

    lcs = lcs_length(target_valid, aligned_pred)
    lcsr = lcs / target_valid_len if target_valid_len > 0 else 0.0

    geo_mean = (psa * lcsr) ** 0.5 if psa > 0 and lcsr > 0 else 0.0
    
    return psa, lcsr, geo_mean

In [None]:
def train_and_evaluate(reverse):
    seq_1 = Seq_1().to(cfg.device)
    seq_2 = Seq_2().to(cfg.device)
    embedding = nn.Embedding(cfg.vocab_size, cfg.embedding_size, device=cfg.device)
    
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(
        list(seq_1.parameters()) + list(seq_2.parameters()) + list(embedding.parameters()),
        lr=cfg.lr
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)

    seq_1.train()
    seq_2.train()
    embedding.train()
    
    for step in tqdm.trange(cfg.steps, desc="Training"):
        data = generate_data(cfg.batch_size, reverse)
        input_lengths = torch.tensor([len(seq) for seq in data[0]], device=cfg.device)
        batch_input = list_2_tensor(data[0])
        batch_target = list_2_tensor(data[1])
        
        embedded = embedding(batch_input)
        hidden = seq_1(embedded, input_lengths)
        outputs = seq_2(hidden, batch_target.size(1))
        
        loss = criterion(outputs.view(-1, cfg.vocab_size), batch_target.view(-1))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    seq_1.eval()
    seq_2.eval()
    embedding.eval()
    
    total_psa, total_lcsr, total_geo, num_samples = 0.0, 0.0, 0.0, 0
    
    with torch.no_grad():
        for _ in range(cfg.test_round):
            data = generate_data(cfg.batch_size, reverse)
            batch_input = list_2_tensor(data[0])
            batch_target = list_2_tensor(data[1])

            embedded = embedding(batch_input)
            hidden = seq_1(embedded, torch.tensor([len(seq) for seq in data[0]], device=cfg.device))
            outputs = seq_2(hidden, batch_target.size(1))
            preds = outputs.argmax(-1).cpu().tolist()

            for idx in range(cfg.batch_size):
                target_seq = data[1][idx]
                pred_seq = preds[idx]
                
                psa, lcsr, geo = calculate_metrics(target_seq, pred_seq)
                total_psa += psa
                total_lcsr += lcsr
                total_geo += geo
                num_samples += 1

    return (
        total_psa / num_samples,
        total_lcsr / num_samples,
        total_geo / num_samples
    )

def main(reverse):
    all_psa, all_lcsr, all_geo = [], [], []
    
    for round in range(cfg.evaluate_round):
        print(f"------ Round {round+1}/{cfg.evaluate_round} ------")
        psa, lcsr, geo = train_and_evaluate(reverse)
        
        all_psa.append(psa)
        all_lcsr.append(lcsr)
        all_geo.append(geo)
        
        print(f"[Round {round+1}] PSA: {psa:.4f}, LCSR: {lcsr:.4f}, GeoMean: {geo:.4f}")

    final_psa = sum(all_psa) / len(all_psa)
    final_lcsr = sum(all_lcsr) / len(all_lcsr)
    final_geo = sum(all_geo) / len(all_geo)
    
    print("\n=== Final Average Metrics ===")
    print(f"PSA: {final_psa:.4f}")
    print(f"LCSR: {final_lcsr:.4f}")
    print(f"GeoMean: {final_geo:.4f}")

In [7]:
main(True)

------ Round 1/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.71it/s]


[Round 1] PSA: 0.0476, LCSR: 0.0622, GeoMean: 0.0482
------ Round 2/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:50<00:00, 135.82it/s]


[Round 2] PSA: 0.2761, LCSR: 0.3000, GeoMean: 0.2785
------ Round 3/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.65it/s]


[Round 3] PSA: 0.0321, LCSR: 0.0655, GeoMean: 0.0327
------ Round 4/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:52<00:00, 132.98it/s]


[Round 4] PSA: 0.3231, LCSR: 0.3438, GeoMean: 0.3263
------ Round 5/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.02it/s]


[Round 5] PSA: 0.0316, LCSR: 0.0635, GeoMean: 0.0323
------ Round 6/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 133.93it/s]


[Round 6] PSA: 0.0809, LCSR: 0.1041, GeoMean: 0.0814
------ Round 7/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.26it/s]


[Round 7] PSA: 0.0326, LCSR: 0.0705, GeoMean: 0.0334
------ Round 8/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.30it/s]


[Round 8] PSA: 0.2516, LCSR: 0.2715, GeoMean: 0.2533
------ Round 9/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.39it/s]


[Round 9] PSA: 0.2158, LCSR: 0.2363, GeoMean: 0.2174
------ Round 10/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.41it/s]


[Round 10] PSA: 0.2154, LCSR: 0.2414, GeoMean: 0.2169

=== Final Average Metrics ===
PSA: 0.1507
LCSR: 0.1759
GeoMean: 0.1520


In [8]:
main(False)

------ Round 1/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 133.95it/s]


[Round 1] PSA: 0.1646, LCSR: 0.1862, GeoMean: 0.1653
------ Round 2/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:54<00:00, 130.59it/s]


[Round 2] PSA: 0.0701, LCSR: 0.0924, GeoMean: 0.0705
------ Round 3/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.81it/s]


[Round 3] PSA: 0.2036, LCSR: 0.2250, GeoMean: 0.2044
------ Round 4/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:52<00:00, 133.84it/s]


[Round 4] PSA: 0.1281, LCSR: 0.1516, GeoMean: 0.1288
------ Round 5/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:55<00:00, 129.80it/s]


[Round 5] PSA: 0.0322, LCSR: 0.0646, GeoMean: 0.0326
------ Round 6/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:52<00:00, 133.15it/s]


[Round 6] PSA: 0.0478, LCSR: 0.0733, GeoMean: 0.0483
------ Round 7/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.05it/s]


[Round 7] PSA: 0.0331, LCSR: 0.0657, GeoMean: 0.0337
------ Round 8/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.29it/s]


[Round 8] PSA: 0.1777, LCSR: 0.2027, GeoMean: 0.1785
------ Round 9/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.63it/s]


[Round 9] PSA: 0.1898, LCSR: 0.2135, GeoMean: 0.1908
------ Round 10/10 ------


Training: 100%|████████████████████████████████████████████████████████████████████████████| 15000/15000 [01:51<00:00, 134.95it/s]


[Round 10] PSA: 0.0368, LCSR: 0.0639, GeoMean: 0.0374

=== Final Average Metrics ===
PSA: 0.1084
LCSR: 0.1339
GeoMean: 0.1090
