In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl

class BinaryCIFARDataset(torch.utils.data.Dataset):
    """View over a subset of another dataset where every sample gets one fixed label.

    Args:
        original_dataset: Indexable dataset whose items are ``(image, target)`` pairs.
        indices: Positions in ``original_dataset`` that belong to this view.
        label: Label returned for every sample, replacing the original target.
    """

    def __init__(self, original_dataset, indices, label):
        self.original_dataset, self.indices, self.label = (
            original_dataset,
            indices,
            label,
        )

    def __len__(self):
        # The view is exactly as long as the list of selected positions.
        return len(self.indices)

    def __getitem__(self, index):
        # Translate the view-local index into a position in the wrapped dataset.
        source_position = self.indices[index]
        image, _original_target = self.original_dataset[source_position]
        # The wrapped dataset's own target is discarded in favour of the fixed label.
        return image, self.label

class BinaryClassifier(pl.LightningModule):
    """ResNet-18 whose final layer is replaced by a 2-way linear head.

    Args:
        fine_tune: When False (default), every parameter of the pretrained
            ResNet is frozen (``requires_grad = False``); the replacement
            ``fc`` head is created afterwards and is therefore the only
            trainable part. When True, all parameters remain trainable.
    """

    def __init__(self, fine_tune=False):
        super().__init__()
        self.resnet = torchvision.models.resnet18(pretrained=True)
        if not fine_tune:
            # Freeze first, then swap the head below, so the new fc layer
            # keeps its default requires_grad=True.
            # NOTE(review): freezing gradients does not by itself stop
            # BatchNorm running statistics from updating while the module is
            # in train mode — confirm whether that matters for this setup.
            for param in self.resnet.parameters():
                param.requires_grad = False
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, 2)

    def forward(self, x):
        """Return the raw 2-class logits produced by the ResNet for ``x``."""
        return self.resnet(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        inputs, labels = batch
        # Functional form avoids building a new loss module on every step.
        loss = nn.functional.cross_entropy(self(inputs), labels)
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log loss and accuracy for one validation batch.

        Returns:
            Dict with ``val_loss``, ``val_acc`` and the ``batch_size`` they
            were computed over (used for weighting in
            ``validation_epoch_end``).
        """
        inputs, labels = batch
        outputs = self(inputs)
        loss = nn.functional.cross_entropy(outputs, labels)
        preds = outputs.argmax(dim=1)
        # Kept as a tensor so no host/device synchronisation is forced per batch.
        accuracy = (preds == labels).float().mean()
        self.log("val_loss", loss, on_step=True, on_epoch=True)
        self.log("val_acc", accuracy, on_step=True, on_epoch=True)
        return {
            "val_loss": loss,
            "val_acc": accuracy,
            "batch_size": labels.size(0),
        }

    def validation_epoch_end(self, outputs):
        """Log sample-weighted epoch averages of validation loss and accuracy.

        Each batch is weighted by its size so that a smaller final batch does
        not count as much as a full one. Entries without a ``batch_size`` key
        are weighted as 1.
        """
        if not outputs:
            # Nothing was validated; there is no average to report.
            return

        losses = torch.stack(
            [torch.as_tensor(out["val_loss"]).detach() for out in outputs]
        ).float()
        accs = torch.stack(
            [
                torch.as_tensor(
                    out["val_acc"], dtype=torch.float32, device=losses.device
                )
                for out in outputs
            ]
        )
        weights = torch.tensor(
            [out.get("batch_size", 1) for out in outputs],
            dtype=torch.float32,
            device=losses.device,
        )
        total = weights.sum()

        avg_val_loss = (losses * weights).sum() / total
        avg_val_acc = (accs * weights).sum() / total

        self.log("avg_val_loss", avg_val_loss, on_epoch=True)
        self.log("avg_val_acc", avg_val_acc, on_epoch=True)

    def configure_optimizers(self):
        """Return an Adam optimizer (lr=0.001) over the module's parameters."""
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        return optimizer

def partition_indices_by_label(targets, positive_label):
    """Split positions of ``targets`` into (matching, non-matching) index lists.

    Args:
        targets: Sequence of class labels, one per sample.
        positive_label: Label whose positions go into the first list.

    Returns:
        Tuple ``(positive_indices, negative_indices)``, each in ascending order.
    """
    positive_indices, negative_indices = [], []
    for position, target in enumerate(targets):
        if target == positive_label:
            positive_indices.append(position)
        else:
            negative_indices.append(position)
    return positive_indices, negative_indices


def _build_binary_dataset(dataset, positive_label=0):
    """Wrap ``dataset`` as a binary problem: ``positive_label`` -> 0, the rest -> 1.

    Labels are read from ``dataset.targets`` so that no image has to be
    loaded or transformed just to inspect its class.
    """
    positive_indices, negative_indices = partition_indices_by_label(
        dataset.targets, positive_label
    )
    return torch.utils.data.ConcatDataset(
        [
            BinaryCIFARDataset(dataset, positive_indices, label=0),
            BinaryCIFARDataset(dataset, negative_indices, label=1),
        ]
    )


if __name__ == "__main__":
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # Random augmentation is applied to the training images only.
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                          transforms.RandomCrop(32, padding=4),
                                          transforms.ToTensor(),
                                          normalize])
    # Evaluation uses a deterministic pipeline so validation metrics are
    # comparable from epoch to epoch.
    eval_transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
    test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=eval_transform)

    #### train_loader ####
    # Class 0 (airplane) keeps label 0; every other class is relabelled 1.
    train_dataset = _build_binary_dataset(train_dataset, positive_label=0)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)

    #### test_loader ####
    test_dataset = _build_binary_dataset(test_dataset, positive_label=0)
    # No shuffling for evaluation: order does not affect the metrics.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

    model = BinaryClassifier()
    # Use one GPU when CUDA is available, otherwise fall back to CPU (gpus=0).
    trainer = pl.Trainer(max_epochs=1000, gpus=1 if torch.cuda.is_available() else 0)
    trainer.fit(model, train_loader, test_loader)


Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type   | Params
----------------------------------
0 | resnet | ResNet | 11.2 M
----------------------------------
1.0 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)


Epoch 293:  79%|███████▊  | 738/939 [00:06<00:01, 109.79it/s, loss=0.245, v_num=37]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


Epoch 293:  79%|███████▊  | 738/939 [00:16<00:04, 43.99it/s, loss=0.245, v_num=37] 