In [1]:
import torch
import torch.nn as nn
import torch.optim as optim


In [2]:
class Transformer(nn.Module):
    def __init__(self, input_dim, emb_dim, n_heads, num_layers, ff_dim, output_dim):
        super(Transformer, self).__init__()
        
        # Embedding para la entrada
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.positional_encoding = nn.Parameter(torch.zeros(1, 100, emb_dim))  # 100: Máxima longitud
        
        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads, dim_feedforward=ff_dim)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Capa final para clasificación (o regresión)
        self.fc = nn.Linear(emb_dim, output_dim)

    def forward(self, src):
        # Agregar embeddings y codificación posicional
        src = self.embedding(src) + self.positional_encoding[:, :src.size(1), :]
        
        # Pasar por el codificador Transformer
        encoded = self.encoder(src)
        
        # Promediar la salida y pasar por una capa lineal
        output = self.fc(encoded.mean(dim=1))
        return output


In [3]:
# Parámetros
input_dim = 1000  # Vocabulario
emb_dim = 128
n_heads = 4
num_layers = 3
ff_dim = 256
output_dim = 10  # Por ejemplo, 10 clases


In [4]:
# Crear modelo, pérdida y optimizador
model = Transformer(input_dim, emb_dim, n_heads, num_layers, ff_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)




In [5]:
# Datos simulados
x = torch.randint(0, input_dim, (32, 50))  # Batch de 32, secuencia de longitud 50
y = torch.randint(0, output_dim, (32,))   # Etiquetas


In [6]:
# Entrenamiento simple
for epoch in range(10):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")


Epoch 1, Loss: 2.348327159881592
Epoch 2, Loss: 2.164961338043213
Epoch 3, Loss: 2.1168925762176514
Epoch 4, Loss: 2.0769429206848145
Epoch 5, Loss: 2.0449275970458984
Epoch 6, Loss: 2.0086050033569336
Epoch 7, Loss: 1.9800664186477661
Epoch 8, Loss: 1.9423402547836304
Epoch 9, Loss: 1.8996480703353882
Epoch 10, Loss: 1.8483021259307861
