In [1]:
! pip install torch



In [None]:
import transformer

In [None]:
import torch
import torch.nn as nn
import math

device = "mps"

class SimpleGPTPredictor(nn.Module):
    """
    
    
    """
    def __init__(self, vocab_size, embed_size, num_heads, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # token数の上限
        self.max_len = max_len
        # このへんでpos encoding
        self.pe = self.positional_encoding(max_len, embed_size) # (max_legnth, embed_size)


        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_size, num_heads, batch_first=True),
            num_layers=2
        )

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(embed_size, num_heads, batch_first=True),
            num_layers=2
        )
        
        self.lm_head = nn.Linear(embed_size, vocab_size)

    def forward(self, src, tgt): # (b, seq_length)
        """
        pe = 
        [
            [12, 123, 1, 4, 5 ,1 ,4],
            [12, 123, 1, 4, 5 ,1 ,4]
        ]
        [
            [13434, 2343,  3234,...],
            [32432, 343324, 4343,...],
            [],
        ]
        """
        # scr（peは↑のような固定値の配列なので、入力エンべディングのサイズに合わせて切り取る）
        src_p = self.pe[:src.size(1), :].to(src.device) # (seq, embed_dim)
        target_p = self.pe[:src.size(1), :].to(tgt.device)

        # ソースをエンコード
        # batch_first=True なので (batch, seq, embed) のまま
        src_embedded = self.embedding(src) + src_p
        encoded = self.encoder(src_embedded) + target_p
        z
        # ターゲットをデコード
        tgt_embedded = self.embedding(tgt)
        
        # ★追加3: 因果マスク (batch_first なので tgt.size(1) = seq_len)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1))
        
        decoded = self.decoder(tgt_embedded, encoded, tgt_mask=tgt_mask)
        output = self.lm_head(decoded)
        
        return output
        
    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz, device=device)) == 1
        mask = mask.transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    import math

    def positional_encoding(max_len, embed_size):
        pe = torch.zeros(max_len, embed_size) # [max_length, embedd_size]
        for pos in range(max_len):
            for i in range(0, embed_size, 2):
                pe[pos, i]     = math.sin(pos / (10000 ** (i / embed_size)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / embed_size)))
        return pe

In [13]:
# Load vocab dictionaries (MUST match training order!)
with open('inputLearnText.txt', 'r', encoding='utf-8') as f:
    text = f.read()
chars = sorted(list(set(text)))  # SORTED for consistency with training!
char_to_id = {ch: i for i, ch in enumerate(chars)}
id_to_char = {i: ch for i, ch in enumerate(chars)}

def text_to_ids(text):
    return [char_to_id[ch] for ch in text]

def ids_to_text(ids):
    return ''.join([id_to_char[i] for i in ids])

print(f"Vocab size: {len(chars)}")

Vocab size: 120


In [16]:
# Load model
# NOTE: You need to RETRAIN with the new model that has positional encoding!
# Old model weights won't work because the architecture changed.
model = SimpleGPTPredictor(vocab_size=len(chars), embed_size=32, num_heads=4)
model_config = "model_35.pth"

if device == 'mps':
  model.load_state_dict(torch.load(model_config, map_location=device))
else:
  model.load_state_dict(torch.load(model_config))

model.to(device)
model.eval()

print("Model loaded successfully!")

Model loaded successfully!


In [17]:
def test_prediction(model: SimpleGPTPredictor, input_text, temperature=1.0):
    """
    Predict next character with temperature control.
    
    Args:
        model: The model to use for prediction
        input_text: Input text string
        temperature: Controls randomness. Higher = more random, Lower = more deterministic
                    temperature=1.0 is neutral, >1.0 is more random, <1.0 is more focused
    """
    input_ids = text_to_ids(input_text)
    input_tensor = torch.tensor([input_ids], device=device)

    with torch.no_grad():
        output = model(input_tensor, input_tensor)
        last_char_probs = output[0, -1, :]
        
        # Apply temperature scaling
        last_char_probs = last_char_probs / temperature
        probs = torch.softmax(last_char_probs, dim=-1)

        # Sample from the distribution (instead of always picking top-1)
        if temperature > 0:
            char_id = torch.multinomial(probs, num_samples=1).item()
        else:
            # If temperature is 0, use greedy (deterministic)
            char_id = torch.argmax(probs).item()
            
        predicted_char = id_to_char[char_id]

        return predicted_char

def generateSeq(model, text, max_length=20, temperature=1.0):
    """
    Generate sequence with temperature control.
    
    Args:
        model: The model to use
        text: Starting text
        max_length: Maximum number of tokens to generate
        temperature: Controls randomness (default 1.0)
    """
    generated = text
    for _ in range(max_length):
        nextSingleToken = test_prediction(model, generated, temperature=temperature)
        generated += nextSingleToken
    return generated

In [18]:
prompt = "Their spiritual substance"

# Try different temperatures
temperatures = [0.5, 0.8, 1.0, 1.2, 1.5]

for temp in temperatures:
    completion = generateSeq(model, prompt, max_length=30, temperature=temp)
    print(f"\n=== Temperature: {temp} ===")
    print(f"入力: {prompt}")
    print(f"出力: {completion}")


=== Temperature: 0.5 ===
入力: Their spiritual substance
出力: Their spiritual substance te thert thereinte the the th

=== Temperature: 0.8 ===
入力: Their spiritual substance
出力: Their spiritual substancerible anche witere then, te th

=== Temperature: 1.0 ===
入力: Their spiritual substance
出力: Their spiritual substance t incernleserriteritererincri

=== Temperature: 1.2 ===
入力: Their spiritual substance
出力: Their spiritual substancedited teding the prospo dupthe

=== Temperature: 1.5 ===
入力: Their spiritual substance
出力: Their spiritual substance rthe 
 Hispaprim, Hens, can a


## Model Architecture

### Key Changes:
1. **batch_first=True** - Transformer layers now use (batch, seq, embed) format
2. **No transpose operations** - Simpler code, more efficient
3. **Vocab Consistency** - Both training and inference use sorted() vocab for consistent token IDs
4. **Correct Mask Dimensions** - Uses `tgt.size(1)` for batch-first format

### Architecture:
- Encoder-Decoder Transformer
- Embedding size: 32
- Attention heads: 4
- Layers: 2 (encoder) + 2 (decoder)
- Causal masking for autoregressive generation