In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from timm.models.swin_transformer import swin_base_patch4_window7_224 
from timm.data import resolve_data_config, create_transform


In [2]:
# Dataset Class
class CustomDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        try:
            if not os.path.exists(img_path):
                print(f"Image not found: {img_path}. Skipping.")
                return self.__getitem__((idx + 1) % len(self.data))
            image = Image.open(img_path).convert('RGB')  # Convert to RGB
        except Exception as e:
            print(f"Skipping corrupted image: {img_path}, Error: {e}")
            return self.__getitem__((idx + 1) % len(self.data))
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)


In [3]:
# Data Module
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, dataset_dir, batch_size=32):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size

        # Use the correct model (swin_base_patch4_window7_224) to resolve the data configuration
        config = resolve_data_config({}, model=swin_base_patch4_window7_224(pretrained=True))
        self.transform = create_transform(**config)

    def prepare_data(self):
        self.data_list = []
        metadata_dir = os.path.join(self.dataset_dir, 'metadata')
        for file_name in os.listdir(metadata_dir):
            metadata_path = os.path.join(metadata_dir, file_name)
            with open(metadata_path, 'r') as file:
                for line in file:
                    tokens = line.strip().split()
                    img_rel_path = tokens[0] + "/" + tokens[1]
                    label = int(tokens[3])
                    img_path = os.path.join(self.dataset_dir, img_rel_path)
                    self.data_list.append((img_path, label))
        print(f"Total valid images loaded: {len(self.data_list)}")

    def setup(self, stage=None):
        dataset = CustomDataset(self.data_list, transform=self.transform)
        train_size = int(0.8 * len(dataset))
        val_size = int(0.1 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        self.train_set, self.val_set, self.test_set = random_split(
            dataset, [train_size, val_size, test_size])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=4, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=4, pin_memory=True)

In [4]:
class SwinTransformerModule(pl.LightningModule):
    def __init__(self, num_classes=2, learning_rate=1e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = swin_base_patch4_window7_224(pretrained=True)
        self.model.head = torch.nn.Linear(self.model.head.in_features, num_classes)

    def forward(self, x):
        x = self.model.forward_features(x)  # Extract features (B, 7, 7, embed_dim)
        x = x.mean(dim=(1, 2))  # Global average pooling across spatial dimensions (B, embed_dim)
        x = self.model.head(x)  # Classification head (B, num_classes)
        return x

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = F.cross_entropy(logits, labels)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    # Hook to print training metrics after each epoch
    def on_train_epoch_end(self):
        train_loss = self.trainer.callback_metrics.get('train_loss', None)
        train_acc = self.trainer.callback_metrics.get('train_acc', None)
        if train_loss is not None and train_acc is not None:
            print(f"Epoch {self.current_epoch + 1}: Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

    # Hook to print validation metrics after each epoch
    def on_validation_epoch_end(self):
        val_loss = self.trainer.callback_metrics.get('val_loss', None)
        val_acc = self.trainer.callback_metrics.get('val_acc', None)
        if val_loss is not None and val_acc is not None:
            print(f"Epoch {self.current_epoch + 1}: Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

    # Hook to print test metrics after testing
    def on_test_epoch_end(self):
        test_loss = self.trainer.callback_metrics.get('test_loss', None)
        test_acc = self.trainer.callback_metrics.get('test_acc', None)
        if test_loss is not None and test_acc is not None:
            print(f"Test Results: Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}")


In [5]:
# Main Execution
dataset_dir = '/kaggle/input/lra-pathfinder-32/pathfinder32/curv_contour_length_14'
data_module = CustomDataModule(dataset_dir, batch_size=64)
model = SwinTransformerModule(learning_rate=1e-4)

logger = TensorBoardLogger("tb_logs", name="swin_transformer")

# Callbacks
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="checkpoints",
    filename="swin-{epoch:02d}-{val_loss:.2f}",
    save_top_k=-1,
    mode="min",
    save_last=True,  # Save the latest checkpoint
)
lr_monitor = LearningRateMonitor(logging_interval='step')

trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    callbacks=[checkpoint_callback, lr_monitor],
    precision=16,
)

print("Trainer setup complete. Starting data preparation...")
data_module.prepare_data()
print("Data preparation complete.")

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

/opt/conda/lib/python3.10/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!


Trainer setup complete. Starting data preparation...
Total valid images loaded: 200000
Data preparation complete.


In [6]:
# Track training time
start_time = time.time()

print("Starting training...")
trainer.fit(model, data_module)

# Calculate training time
end_time = time.time()
training_time = end_time - start_time


Starting training...
Total valid images loaded: 200000


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Epoch 1: Validation Loss: 0.7185, Validation Accuracy: 0.5312


Training: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1: Validation Loss: 0.6943, Validation Accuracy: 0.4945
Epoch 1: Train Loss: 0.6874, Train Accuracy: 0.5781


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2: Validation Loss: 0.6934, Validation Accuracy: 0.4945
Epoch 2: Train Loss: 0.6964, Train Accuracy: 0.4375


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3: Validation Loss: 0.6932, Validation Accuracy: 0.4945
Epoch 3: Train Loss: 0.6936, Train Accuracy: 0.4688


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4: Validation Loss: 0.6931, Validation Accuracy: 0.4945
Epoch 4: Train Loss: 0.6926, Train Accuracy: 0.6094


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5: Validation Loss: 0.6930, Validation Accuracy: 0.5055
Epoch 5: Train Loss: 0.6945, Train Accuracy: 0.4219


In [7]:
# Test the model
print("Testing the model...")
test_result = trainer.test(model, data_module)

# Calculate Metrics
test_acc = test_result[0]['test_acc']  # Get the test accuracy
num_params = sum(p.numel() for p in model.parameters())
efficiency = test_acc / (torch.log(torch.tensor(training_time)) * torch.log(torch.tensor(num_params)))

# Print Metrics
print("\nTraining Metrics:")
print(f"Time Taken: {training_time:.2f} seconds")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Number of Parameters: {num_params}")
print(f"Efficiency Score: {efficiency:.4f}")


Testing the model...
Total valid images loaded: 200000


Testing: |          | 0/? [00:00<?, ?it/s]

Test Results: Loss: 0.6931, Accuracy: 0.5023



Training Metrics:
Time Taken: 16384.67 seconds
Test Accuracy: 0.5023
Number of Parameters: 86745274
Efficiency Score: 0.0028
