In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

In [2]:
# Dataset Class
class CustomDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        try:
            if not os.path.exists(img_path):
                print(f"Image not found: {img_path}. Skipping.")
                return self.__getitem__((idx + 1) % len(self.data))
            image = Image.open(img_path).convert('RGB')  # Convert to RGB
        except Exception as e:
            print(f"Skipping corrupted image: {img_path}, Error: {e}")
            return self.__getitem__((idx + 1) % len(self.data))
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)


In [3]:
# Data Module
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, dataset_dir, batch_size=32):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),  # Resize to 224x224
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def prepare_data(self):
        self.data_list = []
        metadata_dir = os.path.join(self.dataset_dir, 'metadata')
        for file_name in os.listdir(metadata_dir):
            metadata_path = os.path.join(metadata_dir, file_name)
            with open(metadata_path, 'r') as file:
                for line in file:
                    tokens = line.strip().split()
                    img_rel_path = tokens[0] + "/" + tokens[1]
                    label = int(tokens[3])
                    img_path = os.path.join(self.dataset_dir, img_rel_path)
                    self.data_list.append((img_path, label))
        print(f"Total valid images loaded: {len(self.data_list)}")

    def setup(self, stage=None):
        dataset = CustomDataset(self.data_list, transform=self.transform)
        train_size = int(0.8 * len(dataset))
        val_size = int(0.1 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        self.train_set, self.val_set, self.test_set = random_split(
            dataset, [train_size, val_size, test_size])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=4, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=4, pin_memory=True)


In [4]:
# Patch Embedding Layer
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)  # Shape: [B, embed_dim, num_patches^(1/2), num_patches^(1/2)]
        x = x.flatten(2)  # Flatten spatial dimensions: [B, embed_dim, num_patches]
        x = x.transpose(1, 2)  # Shape: [B, num_patches, embed_dim]
        return x


# Transformer Encoder
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim=768, num_heads=8, depth=6, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=int(embed_dim * mlp_ratio),
                dropout=dropout,
                batch_first=True,
            ) for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


In [5]:
# Vision Transformer Model
class VisionTransformer(pl.LightningModule):
    def __init__(self, img_size=224, patch_size=16, num_classes=2, embed_dim=768, depth=6, num_heads=8, mlp_ratio=4.0, learning_rate=1e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.patch_embedding = PatchEmbedding(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        self.positional_encoding = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2, embed_dim))
        self.transformer = TransformerEncoder(embed_dim=embed_dim, num_heads=num_heads, depth=depth, mlp_ratio=mlp_ratio)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x += self.positional_encoding
        x = self.transformer(x)
        x = x.mean(dim=1)  # Global average pooling
        return self.head(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        self.log('train_loss', loss)
        self.log('train_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


In [6]:
# Main Execution
dataset_dir = '/kaggle/input/lra-pathfinder-32/pathfinder32/curv_contour_length_14'
data_module = CustomDataModule(dataset_dir, batch_size=64)
model = VisionTransformer(learning_rate=1e-4)

logger = TensorBoardLogger("tb_logs", name="custom_transformer")

# Callbacks
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="checkpoints",
    filename="transformer-{epoch:02d}-{val_loss:.2f}",
    save_top_k=-1,
    mode="min",
    save_last=True,
)
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    callbacks=[checkpoint_callback, lr_monitor],
    precision=16,
)

print("Trainer setup complete. Starting data preparation...")
data_module.prepare_data()
print("Data preparation complete.")

/opt/conda/lib/python3.10/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!


Trainer setup complete. Starting data preparation...
Total valid images loaded: 200000
Data preparation complete.


In [7]:
# Track training time
start_time = time.time()

print("Starting training...")
trainer.fit(model, data_module)

# Calculate training time
end_time = time.time()
training_time = end_time - start_time

Starting training...
Total valid images loaded: 200000


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [8]:
# Test the model
print("Testing the model...")
test_result = trainer.test(model, data_module)  # Logs metrics in `test_acc` and `test_loss`

# Safely retrieve accuracy
accuracy = trainer.callback_metrics.get('test_acc', None)  # Use `test_acc` for testing
if accuracy is not None:
    accuracy = accuracy.item()
else:
    accuracy = 0.0  # Default to 0 if not found
    print("Warning: 'test_acc' not found in metrics. Defaulting to 0.")

# Safely calculate efficiency
num_params = sum(p.numel() for p in model.parameters())
efficiency = accuracy / (torch.log(torch.tensor(training_time)) * torch.log(torch.tensor(num_params)))

# Print Metrics
print("\nTraining Metrics:")
print(f"Time Taken: {training_time:.2f} seconds")
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Number of Parameters: {num_params}")
print(f"Efficiency Score: {efficiency:.4f}")


Testing the model...
Total valid images loaded: 200000


Testing: |          | 0/? [00:00<?, ?it/s]


Training Metrics:
Time Taken: 7497.51 seconds
Test Accuracy: 0.5053
Number of Parameters: 43269890
Efficiency Score: 0.0032
