<a href="https://colab.research.google.com/github/kyungh2e2e/CapstoneDesignProject/blob/main/Lin_VS_Trans.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision torchaudio torchtext linformer

Collecting torchtext
  Downloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl.metadata (7.9 kB)
Downloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchtext
Successfully installed torchtext-0.18.0


In [None]:
import torch
from linformer import Linformer

model = Linformer(
    dim = 512,
    seq_len = 4096,
    depth = 12,
    heads = 8,
    k = 256,
    one_kv_head = True,
    share_kv = True
)

x = torch.randn(1, 4096, 512)
model(x) # (1, 4096, 512)

tensor([[[-2.0257,  1.5839, -0.1621,  ..., -0.2654, -0.0620, -0.6689],
         [-0.7153, -0.9838,  0.8176,  ...,  3.7064,  1.8369,  1.3344],
         [ 0.4395, -1.3652,  1.0835,  ..., -1.3006,  1.3377,  2.5808],
         ...,
         [-0.2495, -0.1632,  0.1230,  ..., -0.5743, -1.5128, -1.1477],
         [ 0.3779,  1.7093, -0.4358,  ..., -0.2257, -1.7048, -0.4126],
         [ 2.0530, -1.8306, -2.1278,  ..., -0.1658,  0.4045,  0.0821]]],
       grad_fn=<AddBackward0>)

In [None]:
!pip install torch torchvision linformer



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from linformer import Linformer
import time

# Hyperparameters
SEQ_LEN = 28  # MNIST 이미지는 28x28, 세로줄을 시퀀스로 처리
EMBED_DIM = 128
N_CLASSES = 10
BATCH_SIZE = 64
EPOCHS = 3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data Preparation (MNIST)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Models
class TransformerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Linear(28, EMBED_DIM)  # 프로젝트 28 픽셀 -> EMBED_DIM
        self.positional_encoding = nn.Parameter(torch.zeros(SEQ_LEN, EMBED_DIM))
        self.transformer = nn.Transformer(d_model=EMBED_DIM, nhead=4, num_encoder_layers=3)
        self.fc = nn.Linear(EMBED_DIM, N_CLASSES)

    def forward(self, x):
        x = self.embedding(x) + self.positional_encoding
        x = x.permute(1, 0, 2)  # (seq_len, batch, embed_dim)
        out = self.transformer(x, x).mean(dim=0)  # 평균 풀링
        return self.fc(out)

class LinformerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Linear(28, EMBED_DIM)
        self.positional_encoding = nn.Parameter(torch.zeros(SEQ_LEN, EMBED_DIM))
        self.linformer = Linformer(dim=EMBED_DIM, seq_len=SEQ_LEN, depth=3, heads=4, k=16)
        self.fc = nn.Linear(EMBED_DIM, N_CLASSES)

    def forward(self, x):
        x = self.embedding(x) + self.positional_encoding
        out = self.linformer(x)
        out = out.mean(dim=1)  # 평균 풀링
        return self.fc(out)

# Training and Evaluation Functions
def train(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0
    start_time = time.time()
    for x, y in dataloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = x.view(x.size(0), 28, 28)  # Reshape to (batch, seq_len, input_dim)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    elapsed = time.time() - start_time
    return total_loss / len(dataloader), elapsed

def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            x = x.view(x.size(0), 28, 28)
            output = model(x)
            loss = criterion(output, y)
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += (pred == y).sum().item()
    accuracy = correct / len(dataloader.dataset)
    return total_loss / len(dataloader), accuracy

# Initialize models, optimizers, and loss function
transformer_model = TransformerModel().to(DEVICE)
linformer_model = LinformerModel().to(DEVICE)

criterion = nn.CrossEntropyLoss()
transformer_optimizer = optim.Adam(transformer_model.parameters(), lr=1e-3)
linformer_optimizer = optim.Adam(linformer_model.parameters(), lr=1e-3)

# Train and evaluate Transformer Model
print("Training Transformer Model...")
for epoch in range(EPOCHS):
    train_loss, train_time = train(transformer_model, train_loader, transformer_optimizer, criterion)
    test_loss, test_accuracy = evaluate(transformer_model, test_loader, criterion)
    print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Train Time = {train_time:.2f}s, Test Loss = {test_loss:.4f}, Test Accuracy = {test_accuracy:.4f}")

# Train and evaluate Linformer Model
print("\nTraining Linformer Model...")
for epoch in range(EPOCHS):
    train_loss, train_time = train(linformer_model, train_loader, linformer_optimizer, criterion)
    test_loss, test_accuracy = evaluate(linformer_model, test_loader, criterion)
    print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Train Time = {train_time:.2f}s, Test Loss = {test_loss:.4f}, Test Accuracy = {test_accuracy:.4f}")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.9MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 475kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 3.80MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 4.14MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Training Transformer Model...
Epoch 1: Train Loss = 1.3077, Train Time = 46.08s, Test Loss = 0.6927, Test Accuracy = 0.7891
Epoch 2: Train Loss = 0.5529, Train Time = 45.15s, Test Loss = 0.4427, Test Accuracy = 0.8782
Epoch 3: Train Loss = 0.3232, Train Time = 44.97s, Test Loss = 0.2061, Test Accuracy = 0.9450

Training Linformer Model...
Epoch 1: Train Loss = 0.2377, Train Time = 21.78s, Test Loss = 0.1294, Test Accuracy = 0.9611
Epoch 2: Train Loss = 0.1082, Train Time = 21.78s, Test Loss = 0.0920, Test Accuracy = 0.9729
Epoch 3: Train Loss = 0.0866, Train Time = 21.82s, Test Loss = 0.0821, Test Accuracy = 0.9728
