In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import SwinForImageClassification, SwinConfig, SwinModel
import torch.nn.functional as F

# Import your dataset and datamodule
from src.dataloader.FireSpreadDataModule import FireSpreadDataModule
from src.dataloader.FireSpreadDataset import FireSpreadDataset

class SwinFineTuner(pl.LightningModule):
    """Pretrained Swin backbone plus a 1x1-conv head that predicts a per-pixel mask logit.

    The backbone weights are loaded from the Hugging Face hub checkpoint
    ``model_name``. The head emits one logit per location of the final Swin
    feature map, which is upsampled back to the input's spatial size and trained
    with ``BCEWithLogitsLoss``.

    Args:
        model_name: Hugging Face hub id of the Swin checkpoint to fine-tune.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(
        self,
        model_name: str = "microsoft/swin-tiny-patch4-window7-224",
        lr: float = 1e-4,
    ):
        super().__init__()
        # Stores model_name / lr in self.hparams and in checkpoints, so
        # SwinFineTuner.load_from_checkpoint can rebuild the module.
        self.save_hyperparameters()

        # Load the pretrained weights. SwinModel(config) alone would build a
        # randomly initialised backbone, i.e. nothing would be fine-tuned.
        self.model = SwinModel.from_pretrained(model_name)
        config = self.model.config
        # NOTE(review): the pretrained patch embedding expects config.num_channels
        # input channels (3 for this checkpoint) — confirm the data module yields
        # image tensors with that many channels.

        # Output stride of the last stage: the patch embedding downsamples by
        # patch_size, then each stage after the first merges 2x2 patches.
        self._stride = config.patch_size * 2 ** (len(config.depths) - 1)

        # 1x1 conv maps each final-stage feature vector to a single mask logit.
        self.classifier = nn.Conv2d(config.hidden_size, 1, kernel_size=1)

        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return mask logits of shape ``[batch, 1, height, width]``.

        Args:
            x: image batch, assumed ``[batch, channels, height, width]``.

        Raises:
            ValueError: if the backbone's token count does not match the
                feature grid expected for the input size.
        """
        height, width = x.shape[-2:]

        # last_hidden_state is a token sequence [batch, seq_len, hidden_size],
        # not a feature map, so it must be folded back into a 2-D grid.
        tokens = self.model(pixel_values=x).last_hidden_state
        batch, seq_len, channels = tokens.shape

        # Swin pads inputs up to whole patches / merges, so the final grid is
        # ceil(height / stride) x ceil(width / stride).
        grid_h = (height + self._stride - 1) // self._stride
        grid_w = (width + self._stride - 1) // self._stride
        if grid_h * grid_w != seq_len:
            raise ValueError(
                f"Expected {grid_h}x{grid_w}={grid_h * grid_w} tokens for a "
                f"{height}x{width} input, got {seq_len}"
            )

        # Tokens are ordered row-major over the grid: [B, L, C] -> [B, C, h, w].
        feature_map = tokens.transpose(1, 2).reshape(batch, channels, grid_h, grid_w)
        logits = self.classifier(feature_map)

        # Upsample the coarse (1/stride) logits to the input resolution so they
        # line up pixel-for-pixel with a full-size mask.
        return F.interpolate(logits, size=(height, width), mode="bilinear", align_corners=False)

    def _shared_step(self, batch):
        """Compute the BCE-with-logits loss for one ``(inputs, labels)`` batch."""
        inputs, labels = batch
        logits = self(inputs)

        # NOTE(review): labels are assumed to be a per-pixel 0/1 mask shaped
        # [batch, H, W] or [batch, 1, H, W] — confirm against FireSpreadDataset.
        if labels.dim() == logits.dim() - 1:
            labels = labels.unsqueeze(1)
        if labels.shape[-2:] != logits.shape[-2:]:
            logits = F.interpolate(
                logits, size=tuple(labels.shape[-2:]), mode="bilinear", align_corners=False
            )

        # BCEWithLogitsLoss requires floating-point targets matching the logits.
        return self.criterion(logits, labels.to(logits.dtype))

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch and log it as ``train_loss``."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters (backbone and head) at ``self.hparams.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch and log it as ``val_loss``.

        ``val_loss`` is the metric monitored by the ModelCheckpoint callback.
        """
        loss = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        return loss

# Parameters
batch_size = 16
SEED = 42

# Seed Python, NumPy and torch RNGs before the model is built, so the randomly
# initialised classifier head is reproducible across runs.
pl.seed_everything(SEED)

# Initialize the data module.
# NOTE(review): only batch_size is passed — confirm FireSpreadDataModule has no
# other required constructor arguments (e.g. a data directory).
data_module = FireSpreadDataModule(batch_size=batch_size)

# Initialize the model (pretrained Swin backbone + mask head).
model = SwinFineTuner()

# Keep only the single checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="./checkpoints",
    filename="swin-transformer-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min",
)

# Initialize the trainer.
# `gpus=` is deprecated and was removed in Lightning 2.0, and `gpus=1` fails on
# CPU-only machines. accelerator="auto" uses a GPU when one is available and
# falls back to CPU otherwise.
trainer = Trainer(
    max_epochs=1,
    accelerator="auto",
    devices=1,
    callbacks=[checkpoint_callback],
)



ModuleNotFoundError: No module named 'transformers'

In [None]:

# Run training, pulling the dataloaders from the data module.
trainer.fit(model=model, datamodule=data_module)
