In [1]:
!pip install einops

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from einops import rearrange, repeat

In [3]:
# Dataset Class
class CustomDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        try:
            if not os.path.exists(img_path):
                print(f"Image not found: {img_path}. Skipping.")
                return self.__getitem__((idx + 1) % len(self.data))
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Skipping corrupted image: {img_path}, Error: {e}")
            return self.__getitem__((idx + 1) % len(self.data))
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)



In [4]:
# Data Module
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, dataset_dir, batch_size=32):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def prepare_data(self):
        self.data_list = []
        metadata_dir = os.path.join(self.dataset_dir, 'metadata')
        for file_name in os.listdir(metadata_dir):
            metadata_path = os.path.join(metadata_dir, file_name)
            with open(metadata_path, 'r') as file:
                for line in file:
                    tokens = line.strip().split()
                    img_rel_path = tokens[0] + "/" + tokens[1]
                    label = int(tokens[3])
                    img_path = os.path.join(self.dataset_dir, img_rel_path)
                    self.data_list.append((img_path, label))
        print(f"Total valid images loaded: {len(self.data_list)}")

    def setup(self, stage=None):
        dataset = CustomDataset(self.data_list, transform=self.transform)
        train_size = int(0.8 * len(dataset))
        val_size = int(0.1 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        self.train_set, self.val_set, self.test_set = random_split(
            dataset, [train_size, val_size, test_size])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=4, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=4, pin_memory=True)


In [5]:
# Transformer Model
class ImageTransformer(pl.LightningModule):
    def __init__(self, img_size=224, patch_size=16, num_classes=2, embed_dim=768, depth=6, num_heads=8, mlp_ratio=4.0, learning_rate=1e-4):
        super().__init__()
        self.learning_rate = learning_rate

        # Patch Embedding
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2

        # Positional Encoding
        self.positional_encoding = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        # Transformer Encoder
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            activation='gelu',
        )
        self.transformer = nn.TransformerEncoder(self.encoder_layer, num_layers=depth)

        # Classification Head
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Convert image to patches
        x = self.patch_embed(x)  # Shape: (B, embed_dim, H/patch_size, W/patch_size)
        x = rearrange(x, "b c h w -> b (h w) c")  # Flatten patches

        # Add positional encoding
        x += self.positional_encoding

        # Add classification token
        cls_tokens = repeat(self.cls_token, "1 1 c -> b 1 c", b=x.size(0))
        x = torch.cat([cls_tokens, x], dim=1)

        # Transformer Encoding
        x = self.transformer(x)

        # Classification head
        cls_output = x[:, 0]  # Take the CLS token
        return self.head(cls_output)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        self.log('train_loss', loss)
        self.log('train_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)



In [6]:
# Main Execution
dataset_dir = '/kaggle/input/lra-pathfinder-32/pathfinder32/curv_contour_length_14'
data_module = CustomDataModule(dataset_dir, batch_size=64)
model = ImageTransformer(learning_rate=1e-4)

logger = TensorBoardLogger("tb_logs", name="image_transformer")

# Callbacks
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="checkpoints",
    filename="image-transformer-{epoch:02d}-{val_loss:.2f}",
    save_top_k=-1,
    mode="min",
    save_last=True,
)
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = pl.Trainer(
    max_epochs=8,
    logger=logger,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    callbacks=[checkpoint_callback, lr_monitor],
    precision=16,
)

print("Trainer setup complete. Starting data preparation...")
data_module.prepare_data()
print("Data preparation complete.")

/opt/conda/lib/python3.10/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!


Trainer setup complete. Starting data preparation...
Total valid images loaded: 200000
Data preparation complete.


In [7]:
print("Starting training...")
trainer.fit(model, data_module)

Starting training...
Total valid images loaded: 200000


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [8]:
print("Testing the model...")
trainer.test(model, data_module)

# Calculate Metrics
test_result = trainer.callback_metrics
test_acc = test_result.get('test_acc', None)
if test_acc is not None:
    test_acc = test_acc.item()
else:
    test_acc = 0.0

# Additional Metrics
training_time = 1  # Replace with actual training time
num_params = sum(p.numel() for p in model.parameters())
efficiency = test_acc / (torch.log(torch.tensor(training_time)) * torch.log(torch.tensor(num_params)))

# Print Metrics
print("\nTraining Metrics:")
print(f"Time Taken: {training_time:.2f} seconds")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Number of Parameters: {num_params}")
print(f"Efficiency Score: {efficiency:.4f}")

Testing the model...
Total valid images loaded: 200000


Testing: |          | 0/? [00:00<?, ?it/s]


Training Metrics:
Time Taken: 1.00 seconds
Test Accuracy: 0.4976
Number of Parameters: 50358530
Efficiency Score: inf
