In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
import lightning.pytorch as pl
from torch.utils.data import TensorDataset, DataLoader

print("torch:", torch.__version__)
print("lightning:", pl.__version__)

torch: 2.3.0+cu121
lightning: 2.3.0


In [23]:
class XorDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=4):
        super().__init__()
        self.batch_size = batch_size
    
    def prepare_data(self):
        self.x = torch.Tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
        self.y = torch.Tensor([[0.], [1.], [1.], [0.]])
    
    def setup(self, stage = None):
        pass

    def train_dataloader(self):
        ds = TensorDataset(self.x, self.y)
        train_loader = DataLoader(dataset = ds, batch_size = self.batch_size)
        return train_loader

    def predict_dataloader(self):
        ds = TensorDataset(self.x, self.y)
        train_loader = DataLoader(dataset = ds, batch_size = 1)
        return train_loader

In [24]:
class XorClassificationModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(2, 4),
            nn.ReLU(),
            nn.Linear(4, 1)
        )
        self.loss = nn.MSELoss()

    def configure_optimizers(self):
        return optim.Adam(model.parameters(), lr=0.01)
        
    def forward(self, inputs):
        return self.linear_stack(inputs)
   
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        self.log("train_loss", loss, prog_bar = True, on_epoch=True, on_step=True)
        return loss

    def predict_step(self, batch, batch_idx):
        x, y = batch
        return self(x)

In [25]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

datamodel = XorDataModule()
model = XorClassificationModule()

trainer = pl.Trainer(
    max_epochs = 1000,
    callbacks=[EarlyStopping(monitor="train_loss", mode="min", patience=5, min_delta=0.0001)]
)
trainer.fit(model, datamodule = datamodel)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type       | Params | Mode 
----------------------------------------------------
0 | linear_stack | Sequential | 17     | train
1 | loss         | MSELoss    | 0      | train
----------------------------------------------------
17        Trainable params
0         Non-trainable params
17        Total params
0.000     Total estimated model params size (MB)


Training: |                                                                                      | 0/? [00:00<…

In [26]:
dl = datamodel.train_dataloader()

predictions = trainer.predict(model = model, dataloaders = dl)
y_hat =  torch.round(torch.cat(predictions)).reshape(-1,)

labels = [label for data, label in dl]
y_true = torch.stack(labels).reshape(-1,)
print("expected:", y_true)
print("predicted:", y_hat)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |                                                                                    | 0/? [00:00<…

expected: tensor([0., 1., 1., 0.])
predicted: tensor([-0., 1., 1., 0.])
