In [1]:
import os
import time
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics


In [2]:
class PathfinderDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('L')  # Grayscale image
        if self.transform:
            image = self.transform(image)
        return image, label

    def collate_fn(self, batch):
        # Custom collate function to ensure consistent tensor shapes
        images, labels = zip(*batch)
        images = torch.stack(list(images))
        labels = torch.tensor(labels)
        return images, labels

In [3]:
class PathfinderDataModule(pl.LightningDataModule):
    def __init__(self, dataset_dir, batch_size=32):
        super().__init__()
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))  # Normalize grayscale images
        ])

    def prepare_data(self):
        # Load the metadata and create a data list
        self.data_list = []
        metadata_dir = os.path.join(self.dataset_dir, 'metadata')
        for file_name in os.listdir(metadata_dir):
            metadata_path = os.path.join(metadata_dir, file_name)
            with open(metadata_path, 'r') as file:
                for line in file:
                    tokens = line.strip().split()
                    img_rel_path = tokens[0] + "/" + tokens[1]
                    label = int(tokens[3])  # Assuming label is the fourth value
                    img_path = os.path.join(self.dataset_dir, img_rel_path)
                    self.data_list.append((img_path, label))

    def setup(self, stage=None):
        # Split the data into train, val, and test sets
        dataset = PathfinderDataset(self.data_list, transform=self.transform)
        train_size = int(0.8 * len(dataset))
        val_size = int(0.1 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        self.train_set, self.val_set, self.test_set = random_split(
            dataset, [train_size, val_size, test_size])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=4, collate_fn=self.train_set.dataset.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=4, collate_fn=self.val_set.dataset.collate_fn)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=4, collate_fn=self.test_set.dataset.collate_fn)


In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1024):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(1)].unsqueeze(0)

class PathfinderTransformer(pl.LightningModule):
    def __init__(self, d_model=256, nhead=8, num_layers=4, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        
        # Initial convolution layers to reduce spatial dimensions and extract features
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, d_model, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(d_model)
        )
        
        # Flatten and positional encoding
        self.pos_encoder = PositionalEncoding(d_model)
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=dim_feedforward, 
            dropout=dropout,
            activation='relu'
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2)
        )
        
        # Metrics
        self.train_accuracy = torchmetrics.Accuracy(task='binary')
        self.val_accuracy = torchmetrics.Accuracy(task='binary')
        self.test_accuracy = torchmetrics.Accuracy(task='binary')
        
        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        # Ensure input is 4D (batch, channels, height, width)
        if x.dim() == 2:
            # If input is 2D, assume it's (batch, sequence)
            # Reshape to (batch, channels, height, width)
            x = x.view(-1, 1, 32, 32)
        
        # Extract features from input image
        x = self.feature_extractor(x)
        
        # Reshape to sequence
        x = x.flatten(2)  # (batch, channels, sequence_length)
        x = x.permute(2, 0, 1)  # (sequence_length, batch, channels)
        
        # Add positional encoding
        x = self.pos_encoder(x)
        
        # Pass through transformer encoder
        x = self.transformer_encoder(x)
        
        # Global average pooling over sequence
        x = x.mean(dim=0)
        
        # Classification
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        
        # Log accuracy
        preds = torch.argmax(logits, dim=1)
        self.train_accuracy(preds, y)
        
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_accuracy', self.train_accuracy, prog_bar=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        
        # Log accuracy
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy(preds, y)
        
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_accuracy', self.val_accuracy, prog_bar=True)
        
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        
        # Log accuracy
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy(preds, y)
        
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_accuracy', self.test_accuracy, prog_bar=True)
        
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(), 
            lr=1e-3, 
            weight_decay=1e-5
        )
        
        # Learning rate scheduler
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode='min', 
            factor=0.5, 
            patience=3, 
            verbose=True
        )
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss'
            }
        }

# Example of how to use the model
def train_pathfinder_model(dataset_dir):
    # Set up data module
    data_module = PathfinderDataModule(
        dataset_dir=dataset_dir, 
        batch_size=32
    )
    
    # Initialize model
    model = PathfinderTransformer(
        d_model=256, 
        nhead=8, 
        num_layers=4, 
        dim_feedforward=1024, 
        dropout=0.1
    )
    
    # Initialize Lightning Trainer
    trainer = pl.Trainer(
        max_epochs=50,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1 if torch.cuda.is_available() else None,
        precision=16 if torch.cuda.is_available() else 32,
        callbacks=[
            pl.callbacks.EarlyStopping(
                monitor='val_loss', 
                patience=10, 
                mode='min'
            ),
            pl.callbacks.ModelCheckpoint(
                monitor='val_accuracy', 
                mode='max', 
                save_top_k=1
            )
        ]
    )
    
    # Train the model
    start_time = time.time()
    trainer.fit(model, data_module)
    train_time = time.time() - start_time

    print(trainer.callback_metrics)
    
    # Test the model
    trainer.test(model, data_module)
    print(trainer.callback_metrics)
    accuracy = trainer.callback_metrics['test_accuracy']

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    efficiency = accuracy / (torch.log(torch.tensor(train_time)) * torch.log(torch.tensor(num_params)))

    
    return model, trainer, train_time, accuracy, num_params, efficiency

In [5]:
dataset_dir = '/kaggle/input/lra-pathfinder-32/pathfinder32/curv_contour_length_14'  # Update with your dataset path

model, trainer, train_time, accuracy, num_params, efficiency = train_pathfinder_model(dataset_dir)

/opt/conda/lib/python3.10/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

{'train_loss': tensor(0.6943), 'train_accuracy': tensor(0.3438), 'val_loss': tensor(0.6931), 'val_accuracy': tensor(0.4992)}


Testing: |          | 0/? [00:00<?, ?it/s]

{'test_loss': tensor(0.6931), 'test_accuracy': tensor(0.5017)}


In [6]:
print("\Metrics:")
print(f"Time Taken: {train_time:.2f} seconds")
print(f"Accuracy: {accuracy:.4f}")
print(f"Number of Parameters: {num_params}")
print(f"Efficiency Score: {efficiency:.4f}")

\Metrics:
Time Taken: 2926.44 seconds
Accuracy: 0.5017
Number of Parameters: 3359426
Efficiency Score: 0.0042
