In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import os
import zipfile
import requests
from tqdm import tqdm
import pandas as pd
import torch
import numpy as np
from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# this notebook run on apple silicon chip, in order to leverage GPU:
print(torch.__version__)
print(torch.backends.mps.is_available())

2.6.0
True


In [21]:
## set device as mps
device = torch.device("mps")

In [26]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        # Initialize dimensions
        self.d_model = d_model # Model's dimension
        self.num_heads = num_heads # Number of attention heads
        self.d_k = d_model // num_heads # Dimension of each head's key, query, and value
        
        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model) # Query transformation
        self.W_k = nn.Linear(d_model, d_model) # Key transformation
        self.W_v = nn.Linear(d_model, d_model) # Value transformation
        self.W_o = nn.Linear(d_model, d_model) # Output transformation
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        # Softmax is applied to obtain attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)
        
        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        # Reshape the input to have num_heads for multi-head attention
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        # Perform scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [27]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [49]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length, device=None):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model, device=device)
        position = torch.arange(0, max_seq_length, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        div_term = div_term.to(device)
        print("Input device:", pe.device)
        print("Target device:", position.device)
        print("Target device:", div_term.device)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [50]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [51]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

In [52]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length, device=device)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2).to(device)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3).to(device)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool().to(device)
#         print (tgt_mask.shape)
#         print (nopeak_mask.shape)
        tgt_mask = tgt_mask & nopeak_mask

        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

In [53]:
src_vocab_size = 32000
tgt_vocab_size = 32000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 20
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)


Input device: mps:0
Target device: mps:0
Target device: mps:0


In [12]:
## Read and preprocess training, test data

# Constants
DOWNLOAD_URL = "https://opus.nlpl.eu/download.php?f=OpenSubtitles/v2018/moses/en-es.txt.zip"
EXTRACT_DIR = "/Users/fangzhang/Desktop/build_transformer/en-es.txt"
ZIP_FILE = "en-es.txt.zip"
OUTPUT_FILE = "opensubtitles_en_es.tsv"

# Step 1: Read and clean the sentence pairs

src_file_path = os.path.join(EXTRACT_DIR, "OpenSubtitles.en-es.en")
tgt_file_path = os.path.join(EXTRACT_DIR, "OpenSubtitles.en-es.es")

print("Processing sentence pairs...")
with open(src_file_path, 'r', encoding='utf-8') as src_file, \
     open(tgt_file_path, 'r', encoding='utf-8') as tgt_file, \
     open(OUTPUT_FILE, 'w', encoding='utf-8') as out_file:
    for src_line, tgt_line in tqdm(zip(src_file, tgt_file), desc="Cleaning", unit="lines"):
        src_line = src_line.strip()
        tgt_line = tgt_line.strip()
        if src_line and tgt_line and len(src_line.split()) < 50 and len(tgt_line.split()) < 50:
            out_file.write(f"{src_line}\t{tgt_line}\n")

print(f"\nSaved cleaned sentence pairs to {OUTPUT_FILE} ✅")


# Step 2: Load the TSV file with error handling
df = pd.read_csv(
    'opensubtitles_en_es.tsv',
    sep='\t',
    header=None,
    names=['en', 'es'],
    engine='python',
    on_bad_lines='warn'  # Skips bad lines and warns about them
)

# Step 3: processing
# Drop rows with null values
df.dropna(inplace=True)

# Filter sentence pairs by length
df = df[df['en'].str.split().str.len().between(3, 50)]
df = df[df['es'].str.split().str.len().between(3, 50)]

# Remove duplicate pairs
df.drop_duplicates(inplace=True)

# Reset index
df.reset_index(drop=True, inplace=True)


# Step 4: Initialize tokenizers for English and Spanish
tokenizer_en = AutoTokenizer.from_pretrained('bert-base-uncased')
tokenizer_es = AutoTokenizer.from_pretrained('dccuchile/bert-base-spanish-wwm-cased')

# Tokenize the sentences
df['en_tokens'] = df['en'].apply(lambda x: tokenizer_en.encode(x, add_special_tokens=True))
df['es_tokens'] = df['es'].apply(lambda x: tokenizer_es.encode(x, add_special_tokens=True))



# Step 5: Split into training and temp (which will be further split into validation and test)
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)

# Split temp into validation and test
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)



# Example pandas Series
# Each value is a list of integers representing tokens

# Parameters
max_len = 20
pad_token = 0  # You can choose any value as the padding token

# Step 6: convert to tensors

def gen_tensor(s):

    # Pad or truncate each list
    padded = s.apply(lambda x: (x + [pad_token]*max_len)[:max_len])

    # Convert to tensor
    data = torch.tensor(padded.tolist(), dtype=torch.long)
    return data

src_data = gen_tensor(train_df['en_tokens'])
tgt_data = gen_tensor(train_df['es_tokens'])

src_test = gen_tensor(test_df['en_tokens'])
tgt_test = gen_tensor(test_df['es_tokens'])



Processing sentence pairs...


Cleaning: 506188lines [00:00, 530722.85lines/s]
Skipping line 589: '	' expected after '"'
Skipping line 1782: '	' expected after '"'
Skipping line 2129: '	' expected after '"'
Skipping line 2352: '	' expected after '"'
Skipping line 2353: '	' expected after '"'
Skipping line 2354: '	' expected after '"'
Skipping line 2580: '	' expected after '"'
Skipping line 2882: '	' expected after '"'
Skipping line 2883: '	' expected after '"'
Skipping line 2908: '	' expected after '"'
Skipping line 2910: '	' expected after '"'
Skipping line 2911: '	' expected after '"'
Skipping line 3355: '	' expected after '"'
Skipping line 3703: '	' expected after '"'
Skipping line 4442: '	' expected after '"'
Skipping line 4696: '	' expected after '"'
Skipping line 4758: '	' expected after '"'
Skipping line 4759: '	' expected after '"'
Skipping line 4760: '	' expected after '"'
Skipping line 4841: '	' expected after '"'
Skipping line 4842: '	' expected after '"'
Skipping line 4843: '	' expected after '"'
Skippin


Saved cleaned sentence pairs to opensubtitles_en_es.tsv ✅


Skipping line 56478: '	' expected after '"'
Skipping line 56483: '	' expected after '"'
Skipping line 56488: '	' expected after '"'
Skipping line 56643: '	' expected after '"'
Skipping line 56785: '	' expected after '"'
Skipping line 57653: '	' expected after '"'
Skipping line 59496: '	' expected after '"'
Skipping line 59943: '	' expected after '"'
Skipping line 59990: '	' expected after '"'
Skipping line 60333: '	' expected after '"'
Skipping line 60651: '	' expected after '"'
Skipping line 60652: '	' expected after '"'
Skipping line 60654: '	' expected after '"'
Skipping line 60655: '	' expected after '"'
Skipping line 60658: '	' expected after '"'
Skipping line 60659: '	' expected after '"'
Skipping line 60660: '	' expected after '"'
Skipping line 60661: '	' expected after '"'
Skipping line 60662: '	' expected after '"'
Skipping line 60663: '	' expected after '"'
Skipping line 60668: '	' expected after '"'
Skipping line 60669: '	' expected after '"'
Skipping line 60670: '	' expecte

Skipping line 154743: '	' expected after '"'
Skipping line 154746: '	' expected after '"'
Skipping line 154788: '	' expected after '"'
Skipping line 154962: '	' expected after '"'
Skipping line 155169: '	' expected after '"'
Skipping line 155404: '	' expected after '"'
Skipping line 155405: '	' expected after '"'
Skipping line 155406: '	' expected after '"'
Skipping line 155407: '	' expected after '"'
Skipping line 155408: '	' expected after '"'
Skipping line 155409: '	' expected after '"'
Skipping line 155410: '	' expected after '"'
Skipping line 155411: '	' expected after '"'
Skipping line 155412: '	' expected after '"'
Skipping line 155413: '	' expected after '"'
Skipping line 155415: '	' expected after '"'
Skipping line 155416: '	' expected after '"'
Skipping line 155417: '	' expected after '"'
Skipping line 156385: '	' expected after '"'
Skipping line 156806: '	' expected after '"'
Skipping line 157119: '	' expected after '"'
Skipping line 157120: '	' expected after '"'
Skipping l

Skipping line 266519: '	' expected after '"'
Skipping line 266538: '	' expected after '"'
Skipping line 266624: '	' expected after '"'
Skipping line 266644: '	' expected after '"'
Skipping line 266646: '	' expected after '"'
Skipping line 266818: '	' expected after '"'
Skipping line 266850: '	' expected after '"'
Skipping line 266855: '	' expected after '"'
Skipping line 266885: '	' expected after '"'
Skipping line 266978: '	' expected after '"'
Skipping line 267131: '	' expected after '"'
Skipping line 267795: '	' expected after '"'
Skipping line 267835: '	' expected after '"'
Skipping line 267994: '	' expected after '"'
Skipping line 268055: '	' expected after '"'
Skipping line 268368: '	' expected after '"'
Skipping line 269012: '	' expected after '"'
Skipping line 269013: '	' expected after '"'
Skipping line 269023: '	' expected after '"'
Skipping line 269033: '	' expected after '"'
Skipping line 269109: '	' expected after '"'
Skipping line 269333: '	' expected after '"'
Skipping l

Skipping line 411272: '	' expected after '"'
Skipping line 411312: '	' expected after '"'
Skipping line 411314: '	' expected after '"'
Skipping line 411375: '	' expected after '"'
Skipping line 411633: '	' expected after '"'
Skipping line 411691: '	' expected after '"'
Skipping line 411706: '	' expected after '"'
Skipping line 411722: '	' expected after '"'
Skipping line 411789: '	' expected after '"'
Skipping line 411790: '	' expected after '"'
Skipping line 411791: '	' expected after '"'
Skipping line 411795: '	' expected after '"'
Skipping line 412000: '	' expected after '"'
Skipping line 412002: '	' expected after '"'
Skipping line 412026: '	' expected after '"'
Skipping line 412225: '	' expected after '"'
Skipping line 412316: '	' expected after '"'
Skipping line 412390: '	' expected after '"'
Skipping line 412424: '	' expected after '"'
Skipping line 412428: '	' expected after '"'
Skipping line 412429: '	' expected after '"'
Skipping line 412581: '	' expected after '"'
Skipping l

Skipping line 13493: Expected 2 fields in line 13493, saw 3
Skipping line 41578: Expected 2 fields in line 41578, saw 3
Skipping line 65714: Expected 2 fields in line 65714, saw 3
Skipping line 72980: Expected 2 fields in line 72980, saw 3
Skipping line 74945: Expected 2 fields in line 74945, saw 3
Skipping line 104157: Expected 2 fields in line 104157, saw 3
Skipping line 105090: Expected 2 fields in line 105090, saw 3
Skipping line 110553: Expected 2 fields in line 110553, saw 3
Skipping line 131137: Expected 2 fields in line 131137, saw 3
Skipping line 131180: Expected 2 fields in line 131180, saw 3
Skipping line 131245: Expected 2 fields in line 131245, saw 3
Skipping line 131495: Expected 2 fields in line 131495, saw 3
Skipping line 131520: Expected 2 fields in line 131520, saw 3
Skipping line 131536: Expected 2 fields in line 131536, saw 3
Skipping line 131557: Expected 2 fields in line 131557, saw 3
Skipping line 131565: Expected 2 fields in line 131565, saw 3
Skipping line 1315

In [14]:
src_data.shape

torch.Size([284734, 20])

In [54]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

model = transformer
model.to(device)

# transformer.train()
model.train()

Transformer(
  (encoder_embedding): Embedding(32000, 512)
  (decoder_embedding): Embedding(32000, 512)
  (positional_encoding): PositionalEncoding()
  (encoder_layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (self_attn): MultiHeadAttention(
        (W_q): Linear(in_features=512, out_features=512, bias=True)
        (W_k): Linear(in_features=512, out_features=512, bias=True)
        (W_v): Linear(in_features=512, out_features=512, bias=True)
        (W_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_forward): PositionWiseFeedForward(
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (relu): ReLU()
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (decoder_layers): ModuleList(
    (0-5): 6 x DecoderLayer(

In [56]:
dataset = TensorDataset(src_data, tgt_data)

# Create DataLoader with batching
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Usage in training loop
for batch in loader:
    src_batch, tgt_batch = batch  # both are shape (batch_size, seq_len)
    src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
    # Feed into transformer model here
    
    for epoch in range(2):
        optimizer.zero_grad()
        output = model(src_batch, tgt_batch[:, :-1])
        loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_batch[:, 1:].contiguous().view(-1))
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch+1}, Loss: {loss.item()}")


In [25]:
transformer.eval()

src_test = gen_tensor(test_df.iloc[:1900,2])
tgt_test = gen_tensor(test_df.iloc[:1900,3])
src_test, tgt_test = src_test.to(device), tgt_test.to(device)

with torch.no_grad():

    val_output = transformer(src_test, tgt_test[:, :-1])
    val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size), tgt_test[:, 1:].contiguous().view(-1))
    print(f"Validation Loss: {val_loss.item()}")

Validation Loss: 3.561875820159912
