In [20]:
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import sys
from typing import List, Optional

import pytorch_lightning as pl
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision
from torch import Tensor
from torch import nn

# NOTE(review): machine-specific absolute path; must stay before the imports
# below that may be resolved through it.
sys.path.append("C:/Users/isxzl/OneDrive/Code/AutoSSL")
from autoSSL.utils.knn import knn_predict
from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform

# Input folders handed to LightlyDataset further down; both are relative to
# the directory the notebook is executed from.
path_to_train_cifar10 = "../Datasets/cifar10/train/"
path_to_test_cifar10 = "../Datasets/cifar10/test/"
class BarlowTwins(pl.LightningModule):
    """Barlow Twins model on a ResNet-18 backbone with optional kNN monitoring.

    Training minimises ``BarlowTwinsLoss`` between the projections of two
    augmented views. When ``MonitoringbyKNN`` is given, every validation run
    first embeds a feature bank with the backbone and then classifies the
    validation images with ``knn_predict``; the resulting top-1 accuracy is
    logged as ``kNN_accuracy``.

    Args:
        MonitoringbyKNN: ``None`` to disable kNN monitoring, otherwise a
            two-element sequence ``[feature_bank_dataloader, num_classes]``.
            The dataloader must yield ``(images, targets, _)`` triples.
    """

    def __init__(self, MonitoringbyKNN=None):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # Keep every ResNet child except the last one; the projection head
        # below consumes the resulting 512-dimensional features.
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048)
        self.criterion = BarlowTwinsLoss()

        # kNN state is always initialised so that the validation hooks never
        # touch a missing attribute, whether or not monitoring is enabled.
        self.knn_k = 200
        self.knn_t = 0.1
        self.max_accuracy = 0.0
        self._train_features: Optional[Tensor] = None
        self._train_targets: Optional[Tensor] = None
        self._val_predicted_labels: List[Tensor] = []
        self._val_targets: List[Tensor] = []

        if MonitoringbyKNN:
            self.dataloader_kNN = MonitoringbyKNN[0]
            self.num_classes = MonitoringbyKNN[1]
        else:
            self.dataloader_kNN = None
            self.num_classes = None

    def forward(self, x):
        """Return the projection-head output for a batch of images."""
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

    def training_step(self, batch, batch_index):
        """Compute and log the Barlow Twins loss for one batch of view pairs."""
        (x0, x1) = batch[0]
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Plain SGD over all parameters with a fixed learning rate."""
        optim = torch.optim.SGD(self.parameters(), lr=0.06)
        return optim

    def _embed(self, images):
        """Return L2-normalised backbone features with shape (batch, features).

        ``flatten(start_dim=1)`` is used instead of ``squeeze()`` so that the
        batch axis survives even when the batch holds a single image.
        """
        features = self.backbone(images).flatten(start_dim=1)
        return F.normalize(features, dim=1)

    @staticmethod
    def _gather_across_processes(tensor):
        """Concatenate ``tensor`` from all processes along dim 0.

        Returns ``tensor`` unchanged when no process group is initialised or
        only one process is running. ``torch.distributed.gather`` does not
        return the gathered tensors, so ``all_gather`` with an explicit
        output list is used instead.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return tensor
        world_size = dist.get_world_size()
        if world_size <= 1:
            return tensor
        # NOTE(review): all_gather expects identically shaped tensors on every
        # rank - confirm the samplers in use guarantee equal batch sizes.
        gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    def on_validation_epoch_start(self) -> None:
        """Rebuild the kNN feature bank from ``self.dataloader_kNN``."""
        if self.dataloader_kNN is None:
            return

        train_features = []
        train_targets = []
        with torch.no_grad():
            for data in self.dataloader_kNN:
                img, target, _ = data
                img = img.to(self.device)
                target = target.to(self.device)
                feature = self._embed(img)
                # gather features and targets from all processes
                train_features.append(self._gather_across_processes(feature))
                train_targets.append(self._gather_across_processes(target))

        if not train_features:
            # Empty bank dataloader: nothing to predict against, so
            # validation_step will skip the kNN prediction.
            self._train_features = None
            self._train_targets = None
            return

        # The bank is stored transposed, i.e. as (features, num_samples).
        self._train_features = torch.cat(train_features, dim=0).t().contiguous()
        self._train_targets = torch.cat(train_targets, dim=0).t().contiguous()

    def validation_step(self, batch, batch_idx) -> None:
        """Predict labels for one validation batch via kNN on the feature bank."""
        if self.dataloader_kNN is None:
            return
        # we can only do kNN predictions once we have a feature bank
        if self._train_features is None or self._train_targets is None:
            return

        images, targets, _ = batch
        feature = self._embed(images)
        predicted_labels = knn_predict(
            feature,
            self._train_features,
            self._train_targets,
            self.num_classes,
            self.knn_k,
            self.knn_t,
        )
        # gather predictions and targets from all processes
        predicted_labels = self._gather_across_processes(predicted_labels)
        targets = self._gather_across_processes(targets)

        self._val_predicted_labels.append(predicted_labels.cpu())
        self._val_targets.append(targets.cpu())

    def on_validation_epoch_end(self) -> None:
        """Log the top-1 kNN accuracy (in percent) and reset the buffers."""
        if self.dataloader_kNN is None:
            return

        if self._val_predicted_labels and self._val_targets:
            predicted_labels = torch.cat(self._val_predicted_labels, dim=0)
            targets = torch.cat(self._val_targets, dim=0)
            # Column 0 of the predictions is compared against the targets.
            top1 = (predicted_labels[:, 0] == targets).float().sum()
            acc = top1 / len(targets)
            if acc > self.max_accuracy:
                self.max_accuracy = acc.item()
            self.log("kNN_accuracy", acc * 100.0, prog_bar=True)

        self._val_predicted_labels.clear()
        self._val_targets.clear()
            
from lightly.data import LightlyDataset
import torch.nn as nn
import torchvision
# Deterministic preprocessing (no augmentation) used for the validation images
# and for the kNN feature bank.
# NOTE(review): mean/std look like the usual ImageNet channel statistics -
# confirm they match the normalisation applied inside SimCLRTransform.
test_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Augmenting transform for self-supervised training on 32x32 inputs.
transform = SimCLRTransform(input_size=32)

dataset = LightlyDataset(input_dir=path_to_train_cifar10, transform=transform)

# The kNN feature bank is built from the *training* images with the
# deterministic transform. Using the test images as the bank would place every
# validation query inside its own bank and inflate the reported accuracy.
knn_bank_dataset = LightlyDataset(
    input_dir=path_to_train_cifar10, transform=test_transforms
)

testdataset = LightlyDataset(input_dir=path_to_test_cifar10, transform=test_transforms)


dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=512,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

knn_bank_dataloader = torch.utils.data.DataLoader(
    knn_bank_dataset,
    batch_size=512,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

testdataloader = torch.utils.data.DataLoader(
    testdataset,
    batch_size=512,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

# Second list element is the number of classes passed on to knn_predict.
model = BarlowTwins(MonitoringbyKNN=[knn_bank_dataloader, 10])




# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)



In [22]:
from autoSSL.utils import ck_callback, join_dir,ContinuousCSVLogger
import torch.nn.functional as F
from autoSSL.utils.knn import knn_predict
# Directory handed to both the checkpoint callback and the CSV logger.
dirr = "experiment_checkpoints/vanilla/"
accelerator = "gpu"

checkpoint_callback = ck_callback(dirr)
csv_logger = ContinuousCSVLogger(save_dir=dirr)

# Single-device run; the validation loop is executed every fifth epoch.
trainer = pl.Trainer(
    max_epochs=50,
    devices=1,
    accelerator=accelerator,
    callbacks=[checkpoint_callback],
    logger=csv_logger,
    check_val_every_n_epoch=5,
)

trainer.fit(
    model=model,
    train_dataloaders=dataloader,
    val_dataloaders=testdataloader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                      | Params
--------------------------------------------------------------
0 | backbone        | Sequential                | 11.2 M
1 | projection_head | BarlowTwinsProjectionHead | 9.4 M 
2 | criterion       | BarlowTwinsLoss           | 0     
--------------------------------------------------------------
20.6 M    Trainable params
0         Non-trainable params
20.6 M    Total params
82.496

Epoch 0: 100%|████████████████████████████████████████████████████████████████| 97/97 [00:26<00:00,  3.69it/s, v_num=3]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                               | 0/20 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                  | 0/20 [00:00<?, ?it/s][A
Validation DataLoader 0:   5%|██▉                                                       | 1/20 [00:00<00:00, 43.62it/s][A
Validation DataLoader 0:  10%|█████▊                                                    | 2/20 [00:00<00:00, 40.53it/s][A
Validation DataLoader 0:  15%|████████▋                                                 | 3/20 [00:00<00:00, 36.93it/s][A
Validation DataLoader 0:  20%|███████████▌                                              | 4/20 [00:00<00:00, 32.49it/s][A
Validation DataLoader 0:  25%|██████████████▌                                           | 5/20 [00:00<00:00,