In [1]:
import random
import tqdm
import math
import tools
import Model
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

In [2]:
class Luong_AttentionTrainer:
    """Train and score a config-supplied encoder/decoder pair on a synthetic task.

    The task is generated on the fly: the input is a random sequence of
    integers in [1, 100] and the target is the sub-sequence of its primes (in
    the original order) followed by the EOS token.

    The ``config`` object must provide: ``encoder`` and ``decoder`` (callables
    that build a module from the config), ``device``, ``vocab_size``,
    ``embedding_size``, ``padding_value``, ``eos``, ``lr``, ``max_length``,
    ``batch_size``, ``steps``, ``test_round`` and ``evaluate_round``.
    """

    def __init__(self, config):
        self.config = config
        self._init_model()

    def _init_model(self):
        """(Re)create the modules, loss, optimizer and LR scheduler from scratch.

        Called from ``__init__`` and again at the start of every round of
        ``train_and_evaluate`` so that each round starts from fresh weights.
        """
        cfg = self.config
        self.encoder = cfg.encoder(cfg).to(cfg.device)
        self.decoder = cfg.decoder(cfg).to(cfg.device)
        self.embedding = nn.Embedding(cfg.vocab_size, cfg.embedding_size, device=cfg.device)

        # Padded target positions are excluded from the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=cfg.padding_value)
        self.optimizer = optim.Adam(
            list(self.encoder.parameters()) +
            list(self.decoder.parameters()) +
            list(self.embedding.parameters()),
            lr=cfg.lr
        )
        # The learning rate is multiplied by 0.1 every 5000 optimizer steps.
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=5000, gamma=0.1)

    @staticmethod
    def is_prime(num):
        """Return True if ``num`` is a prime number (trial division up to sqrt)."""
        if num < 2:
            return False
        for i in range(2, int(num**0.5) + 1):
            if num % i == 0:
                return False
        return True

    def generate_data(self, num_samples, reverse):
        """Generate ``num_samples`` random (input, target) sequence pairs.

        Args:
            num_samples: number of pairs to generate.
            reverse: when true, each *input* sequence is stored reversed; the
                target is always computed from the un-reversed sequence.

        Returns:
            ``(input_data, output_data)``: two lists of lists of ints. Each
            input has between 1 and ``config.max_length`` tokens; each target
            is the primes of the input followed by ``config.eos``.
        """
        input_data = []
        output_data = []
        for _ in range(num_samples):
            # Draw the length first, then the tokens (keeps the RNG call order stable).
            seq_len = random.randint(1, self.config.max_length)
            input_seq = [random.randint(1, 100) for _ in range(seq_len)]
            output_seq = [x for x in input_seq if self.is_prime(x)]
            output_seq.append(self.config.eos)
            if reverse:
                input_seq = input_seq[::-1]
            input_data.append(input_seq)
            output_data.append(output_seq)
        return input_data, output_data

    def list_2_tensor(self, data):
        """Convert a list of int lists into one right-padded ``long`` tensor.

        The result is batch-first, lives on ``config.device`` and is padded
        with ``config.padding_value`` up to the longest sequence in ``data``.
        """
        tensor_list = [torch.tensor(sublist, dtype=torch.long, device=self.config.device) for sublist in data]
        padded_tensor = pad_sequence(tensor_list, batch_first=True, padding_value=self.config.padding_value)
        return padded_tensor

    @staticmethod
    def lcs_length(a, b):
        """Return the length of the longest common subsequence of ``a`` and ``b``."""
        m, n = len(a), len(b)
        # dp[i][j] = LCS length of a[:i] and b[:j].
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if a[i - 1] == b[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
        return dp[m][n]

    def calculate_metrics(self, target, pred):
        """Score one predicted token list against its target token list.

        Both sequences are cut at their first ``config.eos`` token (or kept
        whole if they contain none). The prediction is then truncated, or
        filled with EOS tokens, to the target's length.

        Args:
            target: list of target token ids.
            pred: list of predicted token ids.

        Returns:
            ``(psa, lcsr, geo_mean)`` where ``psa`` is the fraction of target
            positions predicted exactly, ``lcsr`` is the LCS length divided by
            the target length, and ``geo_mean`` is the geometric mean of both.
            If the target is empty (EOS only), all three are 1.0 when the
            prediction is empty as well and 0.0 otherwise.
        """
        eos_token = self.config.eos
        target_valid_len = target.index(eos_token) if eos_token in target else len(target)
        target_valid = target[:target_valid_len]

        if eos_token in pred:
            pred_valid = list(pred[:pred.index(eos_token)])
        else:
            pred_valid = list(pred)

        if target_valid_len == 0:
            # Nothing to emit: stopping immediately is the fully correct answer.
            # Scoring this case as 0 would cap the achievable average below 1.
            score = 1.0 if not pred_valid else 0.0
            return score, score, score

        # Align the prediction to the target length. The EOS filler can never
        # match because target_valid contains no EOS token.
        aligned_pred = pred_valid[:target_valid_len]
        aligned_pred.extend([eos_token] * (target_valid_len - len(aligned_pred)))

        match_count = sum(1 for t, p in zip(target_valid, aligned_pred) if t == p)
        psa = match_count / target_valid_len

        lcs = self.lcs_length(target_valid, aligned_pred)
        lcsr = lcs / target_valid_len

        geo_mean = (psa * lcsr) ** 0.5 if psa > 0 and lcsr > 0 else 0.0

        return psa, lcsr, geo_mean

    def train(self, reverse=True):
        """Run ``config.steps`` optimization steps on freshly generated batches.

        Args:
            reverse: forwarded to ``generate_data``.
        """
        self.encoder.train()
        self.decoder.train()
        self.embedding.train()

        for step in tqdm.trange(self.config.steps, desc="Training"):
            data = self.generate_data(self.config.batch_size, reverse)
            input_lengths = torch.tensor([len(seq) for seq in data[0]], device=self.config.device)
            batch_input = self.list_2_tensor(data[0])
            batch_target = self.list_2_tensor(data[1])

            embedded = self.embedding(batch_input)
            hidden_state_records = self.encoder(embedded, input_lengths)
            # The decoder receives the encoder output and the padded target length.
            outputs = self.decoder(hidden_state_records, batch_target.size(1))

            # Flatten to (tokens, vocab) vs (tokens,) for the token-level loss.
            loss = self.criterion(outputs.view(-1, self.config.vocab_size), batch_target.view(-1))

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()

    def evaluate(self, reverse=True):
        """Score the current model on ``config.test_round`` freshly generated batches.

        Args:
            reverse: forwarded to ``generate_data``.

        Returns:
            ``(mean_psa, mean_lcsr, mean_geo)`` averaged over every evaluated
            sample.
        """
        self.encoder.eval()
        self.decoder.eval()
        self.embedding.eval()

        total_psa, total_lcsr, total_geo, num_samples = 0.0, 0.0, 0.0, 0

        with torch.no_grad():
            for _ in range(self.config.test_round):
                data = self.generate_data(self.config.batch_size, reverse)
                input_lengths = torch.tensor([len(seq) for seq in data[0]], device=self.config.device)
                batch_input = self.list_2_tensor(data[0])
                batch_target = self.list_2_tensor(data[1])

                embedded = self.embedding(batch_input)
                hidden_state_records = self.encoder(embedded, input_lengths)
                outputs = self.decoder(hidden_state_records, batch_target.size(1))

                # Greedy decoding: take the highest-scoring token at each step.
                preds = outputs.argmax(-1).cpu().tolist()

                # Targets are scored from the un-padded Python lists.
                for target_seq, pred_seq in zip(data[1], preds):
                    psa, lcsr, geo = self.calculate_metrics(target_seq, pred_seq)
                    total_psa += psa
                    total_lcsr += lcsr
                    total_geo += geo
                    num_samples += 1

        return (
            total_psa / num_samples,
            total_lcsr / num_samples,
            total_geo / num_samples
        )

    def train_and_evaluate(self, reverse=True, num_rounds=None):
        """Repeat "re-initialize, train, evaluate" and average the metrics.

        Args:
            reverse: forwarded to ``train`` and ``evaluate``.
            num_rounds: number of independent rounds; defaults to
                ``config.evaluate_round`` when ``None``.

        Returns:
            ``(final_psa, final_lcsr, final_geo)``: each metric averaged over
            all rounds.
        """
        if num_rounds is None:
            num_rounds = self.config.evaluate_round

        all_psa, all_lcsr, all_geo = [], [], []

        for round_idx in range(num_rounds):
            print(f"------ Round {round_idx+1}/{num_rounds} ------")
            # Fresh weights and optimizer state for every round.
            self._init_model()
            self.train(reverse)
            psa, lcsr, geo = self.evaluate(reverse)

            all_psa.append(psa)
            all_lcsr.append(lcsr)
            all_geo.append(geo)

            print(f"[Round {round_idx+1}] PSA: {psa:.4f}, LCSR: {lcsr:.4f}, GeoMean: {geo:.4f}")

        final_psa = sum(all_psa) / len(all_psa)
        final_lcsr = sum(all_lcsr) / len(all_lcsr)
        final_geo = sum(all_geo) / len(all_geo)

        print("\n=== Final Average Metrics ===")
        print(f"PSA: {final_psa:.4f}")
        print(f"LCSR: {final_lcsr:.4f}")
        print(f"GeoMean: {final_geo:.4f}")

        return final_psa, final_lcsr, final_geo

In [3]:
class TrainingConfig:
    """Hyper-parameters and component choices for one training run.

    Args:
        encoder: object stored as ``self.encoder``; ``params()`` reads its
            ``name`` attribute.
        decoder: object stored as ``self.decoder``; ``params()`` reads its
            ``name`` attribute.
        embedding_size: stored as ``embedding_size``.
        hidden_size: stored as ``hidden_size``.
        vocab_size: stored as both ``vocab_size`` and ``output_size``.
        max_length: stored as ``max_length``.
        evaluate_round: stored as ``evaluate_round``.
        batch_size: stored as ``batch_size``.
        steps: stored as ``steps``; also determines ``test_round``.
        lr: stored as ``lr``.
        device: device identifier to use; when ``None`` the device is
            auto-detected ('cuda', then 'mps', then 'cpu').
        padding_value: stored as ``padding_value``.
        eos: stored as ``eos``.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 embedding_size=32,
                 hidden_size=128,
                 vocab_size=101,
                 max_length=10,
                 evaluate_round=3,
                 batch_size=32,
                 steps=5000,
                 lr=5e-4,
                 device=None,
                 padding_value=0,
                 eos=100
                ):

        self.encoder = encoder
        self.decoder = decoder
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.output_size = vocab_size
        self.max_length = max_length
        # One tenth of the training steps, but never zero.
        self.test_round = max(1, steps // 10)
        self.evaluate_round = evaluate_round
        self.batch_size = batch_size
        self.steps = steps
        self.lr = lr
        self.padding_value = padding_value
        self.eos = eos

        # Auto-detect the device unless the caller requested one explicitly.
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        else:
            self.device = device

    def params(self):
        """Return the configuration as a flat dict (components by ``name``).

        ``device`` is not included in the returned dict.
        """
        return {
            'encoder.name': self.encoder.name,
            'decoder.name': self.decoder.name,
            'embedding_size': self.embedding_size,
            'hidden_size': self.hidden_size,
            'vocab_size': self.vocab_size,
            'output_size': self.output_size,
            'max_length': self.max_length,
            'test_round': self.test_round,
            'evaluate_round': self.evaluate_round,
            'batch_size': self.batch_size,
            'steps': self.steps,
            'lr': self.lr,
            'padding_value': self.padding_value,
            'eos': self.eos,
        }

In [None]:
# Sweep definition: every encoder x decoder x max_length combination is run.
encoders = [Model.Luong_Encoder]
# Other decoder variants left here for reference:
# [Model.Attention_Decoder, Model.Attention_Decoder_41, Model.Attention_Decoder_42, Model.Attention_Decoder_43]
decoders = [Model.Luong_Decoder_Dot]
# Maximum input lengths 10, 60, ..., 310 (seven values, step 50).
max_lengths = list(range(10, 311, 50))

config_params = [
    {'encoder': encoder, 'decoder': decoder, 'max_length': max_length}
    for encoder in encoders
    for decoder in decoders
    for max_length in max_lengths
]

# One (config params, averaged metrics) tuple is collected per configuration.
result = []
for config_param in config_params:
    train_config = TrainingConfig(**config_param)
    print(train_config.params())

    trainer = Luong_AttentionTrainer(train_config)
    final_psa, final_lcsr, final_geo = trainer.train_and_evaluate(reverse=True)

    result.append((
        train_config.params(),
        {'final_psa': final_psa, 'final_lcsr': final_lcsr, 'final_geo': final_geo},
    ))