In [1]:
import os
import time

import torch
import torchvision
import wandb
from datasets import load_dataset
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torchvision import transforms

from data import get_color_distortion
from model import SimCLR, NTXentLoss

# Choose the compute device by preference: first CUDA GPU, then Apple's MPS
# backend, and finally the CPU. The conditional expression short-circuits, so
# the MPS check only runs when CUDA is unavailable.
device = torch.device(
    "cuda:0" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print("Training on", device)

# Pair transformation: produces two views of the same input, with the supplied
# augmentation applied to the second view only.
class PairTransform:
    """Callable that maps one image to an ``(un-augmented, augmented)`` pair."""

    def __init__(self, transform):
        """Store the augmentation and build the fixed baseline pipeline.

        Args:
            transform: callable applied to the input to produce the second view.
        """
        # Baseline view: tensor conversion followed by per-channel normalisation
        # with mean 0.5 and std 0.5.
        baseline_steps = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transform
        self.original_transform = transforms.Compose(baseline_steps)

    def __call__(self, x):
        """Return ``(original_transform(x), transform(x))`` for input ``x``."""
        plain_view = self.original_transform(x)
        augmented_view = self.transform(x)
        return plain_view, augmented_view

def run_experiment(config, transform):
    """Train one SimCLR model with a given augmentation and log the run to wandb.

    Every training image is wrapped by ``PairTransform(transform)``, so each
    batch carries two views of the same images. The model is trained for
    ``config["epochs"]`` epochs; after each epoch the learning rate used in
    that epoch and the mean batch loss are logged to wandb. The trained model
    is saved to ``./models/augmentation_experiments/<config["transform"]>.pth``.

    Args:
        config: dict of hyperparameters. Keys read here: "dataset",
            "transform", "batch_size", "resnet", "projection_dim",
            "projection", "temperature", "learning_rate", "weight_decay"
            (optional, defaults to 1e-4), "warmup_epochs" and "epochs".
            The whole dict is also passed to ``wandb.init`` as the run config.
        transform: callable producing the augmented view of an image; passed
            to ``PairTransform``.

    Raises:
        ValueError: if ``config["dataset"]`` is not ``"CIFAR-10"``.
    """
    dataset_name = config["dataset"]
    if dataset_name != "CIFAR-10":
        # Fail with a clear message instead of reaching the DataLoader with an
        # unbound `train_dataset`.
        raise ValueError(f'Unsupported dataset {dataset_name!r}; only "CIFAR-10" is implemented.')

    # Resolve and create the checkpoint location before training so that a
    # missing "transform" key or an unwritable directory is reported
    # immediately rather than after the final epoch.
    save_dir = './models/augmentation_experiments'
    save_path = f'{save_dir}/{config["transform"]}.pth'
    os.makedirs(save_dir, exist_ok=True)

    train_dataset = torchvision.datasets.CIFAR10(root='./data', transform=PairTransform(transform), download=True)

    # drop_last=True keeps every batch at exactly config["batch_size"], the
    # size NTXentLoss is constructed with below.
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, drop_last=True)

    model = SimCLR(resnet=config["resnet"], out_dim=config["projection_dim"], projection=config["projection"]).to(device)
    criterion = NTXentLoss(config["batch_size"], device, config["temperature"])
    # Take the optimiser hyperparameters from the config so the values logged
    # to wandb are the ones actually used (they were previously hard-coded).
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config.get("weight_decay", 1e-4),
    )
    scheduler = LinearWarmupCosineAnnealingLR(
        optimizer,
        warmup_epochs=config["warmup_epochs"],
        warmup_start_lr=config["learning_rate"] / 10,
        max_epochs=config["epochs"],
        eta_min=0,
    )

    # Using the run as a context manager finishes it even when training is
    # interrupted or raises, so a failed experiment does not leave a run open.
    with wandb.init(
        project="SimCLR-augmentation-experiments",
        config=config
    ):
        for epoch in range(1, config["epochs"]+1):  # Start epoch numbering from 1
            model.train()
            epoch_start_time = time.time()
            total_loss = 0
            # Each batch is ([original_views, augmented_views], labels); the
            # labels are not used for training here.
            for (view_a, view_b), _ in trainloader:
                optimizer.zero_grad()
                _, _, z_i, z_j = model(view_a.to(device), view_b.to(device))
                loss = criterion(z_i, z_j)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            # get_last_lr() returns one value per parameter group; there is a
            # single group, so log it as a scalar. Read it before stepping the
            # scheduler so it is the rate this epoch was trained with.
            current_lr = scheduler.get_last_lr()[0]
            scheduler.step()

            avg_loss = total_loss / len(trainloader)
            wandb.log({"lr": current_lr, "loss": avg_loss}, step=epoch)

            print(f'Epoch {epoch} finished in {time.time() - epoch_start_time} seconds, Average Loss: {avg_loss}')

        # NOTE: this pickles the whole module (not just a state_dict), so loading
        # the file requires the SimCLR class to be importable.
        torch.save(model, save_path)

    print("Training completed and logs saved.")
    

# Candidate augmentations, keyed by the name used to label each experiment.
augmentations = dict(
    # Random crop resized back to 32x32.
    random_crop=transforms.RandomResizedCrop(32, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
    # Colour distortion pipeline provided by the project's `data` module.
    color_distortion=get_color_distortion(),
    gaussian_blur=transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    # Shear only: rotation is disabled with degrees=0.
    shear=transforms.RandomAffine(degrees=0, shear=30),
    elastic=transforms.ElasticTransform(),
)


# Hyperparameters shared by every experiment in this notebook. The sweep below
# adds a "transform" entry before each run.
config = dict(
    # Optimisation
    batch_size=2048,
    epochs=200,
    learning_rate=3e-4,
    weight_decay=1e-4,
    momentum=0.9,
    # Model and loss
    projection_dim=128,
    projection="nonlinear",
    temperature=0.5,
    resnet=18,
    # Data
    dataset="CIFAR-10",
    # Descriptive labels
    optimizer="AdamW",
    scheduler="CosineAnnealingLR",
    loss="NT-Xent",
    # Learning-rate warm-up
    warmup_epochs=10,
)

# Sweep over every ordered pair of augmentations. When both members of the
# pair are the same augmentation it is applied once, on its own; otherwise the
# first is applied before the second. Every pipeline ends with tensor
# conversion and normalisation.
for name1, transform1 in augmentations.items():
    for name2, transform2 in augmentations.items():
        is_single = name1 == name2
        config["transform"] = name1 if is_single else name1 + "_and_" + name2
        stages = [transform1] if is_single else [transform1, transform2]
        stages += [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        run_experiment(config, transforms.Compose(stages))


  if not hasattr(numpy, tp_name):
  if not hasattr(numpy, tp_name):
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)


Training on mps
Files already downloaded and verified


  scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=config["warmup_epochs"], warmup_start_lr=config["learning_rate"] * 1/10, max_epochs=config["epochs"], eta_min=0)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mstevan-matovic[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1 finished in 73.23056888580322 seconds, Average Loss: 7.639720380306244
Epoch 2 finished in 65.41478681564331 seconds, Average Loss: 7.149856885274251
Epoch 3 finished in 67.92625713348389 seconds, Average Loss: 6.979749242464702
Epoch 4 finished in 59.66105008125305 seconds, Average Loss: 6.888481259346008
Epoch 5 finished in 57.16447424888611 seconds, Average Loss: 6.830951015154521
Epoch 6 finished in 56.95068717002869 seconds, Average Loss: 6.792366723219554
Epoch 7 finished in 56.857964754104614 seconds, Average Loss: 6.768590688705444
Epoch 8 finished in 56.52297401428223 seconds, Average Loss: 6.7470338344573975
Epoch 9 finished in 56.642690896987915 seconds, Average Loss: 6.733247558275859
Epoch 10 finished in 56.65328288078308 seconds, Average Loss: 6.714428742726644
Epoch 11 finished in 56.401756048202515 seconds, Average Loss: 6.701370457808177
Epoch 12 finished in 56.55287504196167 seconds, Average Loss: 6.691389342149098
Epoch 13 finished in 56.443639039993286 secon



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,6.51914


Training completed and logs saved.
Files already downloaded and verified


  scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=config["warmup_epochs"], warmup_start_lr=config["learning_rate"] * 1/10, max_epochs=config["epochs"], eta_min=0)


Epoch 1 finished in 64.12334609031677 seconds, Average Loss: 7.8809581597646075
Epoch 2 finished in 63.78705406188965 seconds, Average Loss: 7.512868881225586
Epoch 3 finished in 63.828185081481934 seconds, Average Loss: 7.361448387304942
Epoch 4 finished in 63.617778062820435 seconds, Average Loss: 7.284385820229848
Epoch 5 finished in 64.01339626312256 seconds, Average Loss: 7.23276025056839
Epoch 6 finished in 63.590866804122925 seconds, Average Loss: 7.1968793869018555
Epoch 7 finished in 63.967740297317505 seconds, Average Loss: 7.165963311990102
Epoch 8 finished in 64.11494207382202 seconds, Average Loss: 7.140089909235637
Epoch 9 finished in 64.20772409439087 seconds, Average Loss: 7.1170707543691
Epoch 10 finished in 63.92160487174988 seconds, Average Loss: 7.095264752705892
Epoch 11 finished in 64.0262598991394 seconds, Average Loss: 7.072635571161906
Epoch 12 finished in 64.26510190963745 seconds, Average Loss: 7.057318568229675
Epoch 13 finished in 64.65180015563965 seconds,

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
loss,█▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,6.70091


Training completed and logs saved.
Files already downloaded and verified


  scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=config["warmup_epochs"], warmup_start_lr=config["learning_rate"] * 1/10, max_epochs=config["epochs"], eta_min=0)


Epoch 1 finished in 79.09375405311584 seconds, Average Loss: 7.618877291679382
Epoch 2 finished in 76.70164394378662 seconds, Average Loss: 7.126332461833954
Epoch 3 finished in 78.1787121295929 seconds, Average Loss: 6.969899157683055
Epoch 4 finished in 77.82795906066895 seconds, Average Loss: 6.879142880439758
Epoch 5 finished in 79.54167127609253 seconds, Average Loss: 6.830473442872365
Epoch 6 finished in 79.62723207473755 seconds, Average Loss: 6.790880044301351
Epoch 7 finished in 79.19032907485962 seconds, Average Loss: 6.764280537764232
Epoch 8 finished in 78.93450808525085 seconds, Average Loss: 6.7430077989896136
Epoch 9 finished in 79.53180432319641 seconds, Average Loss: 6.729019582271576
Epoch 10 finished in 79.00721406936646 seconds, Average Loss: 6.71087114016215
Epoch 11 finished in 78.33979511260986 seconds, Average Loss: 6.696597874164581
Epoch 12 finished in 77.17670392990112 seconds, Average Loss: 6.688126703103383
Epoch 13 finished in 78.64353203773499 seconds, Av



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▄▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,6.52466


Training completed and logs saved.
Files already downloaded and verified


  scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=config["warmup_epochs"], warmup_start_lr=config["learning_rate"] * 1/10, max_epochs=config["epochs"], eta_min=0)


Epoch 1 finished in 59.40078401565552 seconds, Average Loss: 7.666528304417928
Epoch 2 finished in 58.28634285926819 seconds, Average Loss: 7.187915802001953
Epoch 3 finished in 58.58825087547302 seconds, Average Loss: 7.01009076833725
Epoch 4 finished in 58.07065510749817 seconds, Average Loss: 6.9075372616449995
Epoch 5 finished in 57.965107917785645 seconds, Average Loss: 6.845897654692332
Epoch 6 finished in 59.03634810447693 seconds, Average Loss: 6.805721978346507
Epoch 7 finished in 59.7060809135437 seconds, Average Loss: 6.779391229152679
Epoch 8 finished in 60.85646176338196 seconds, Average Loss: 6.756242334842682
Epoch 9 finished in 60.74258303642273 seconds, Average Loss: 6.736295759677887
Epoch 10 finished in 61.50138306617737 seconds, Average Loss: 6.720192809899648
Epoch 11 finished in 59.99506711959839 seconds, Average Loss: 6.707371989885966
Epoch 12 finished in 60.78270101547241 seconds, Average Loss: 6.693694730599721
Epoch 13 finished in 59.64031171798706 seconds, A

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
loss,█▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,6.53176


Training completed and logs saved.
Files already downloaded and verified


  scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=config["warmup_epochs"], warmup_start_lr=config["learning_rate"] * 1/10, max_epochs=config["epochs"], eta_min=0)


Epoch 1 finished in 168.95126914978027 seconds, Average Loss: 7.7206191420555115
Epoch 2 finished in 155.82712411880493 seconds, Average Loss: 7.209270636240642
Epoch 3 finished in 141.6147198677063 seconds, Average Loss: 7.00777268409729
Epoch 4 finished in 159.22644186019897 seconds, Average Loss: 6.902894715468089
Epoch 5 finished in 152.84524488449097 seconds, Average Loss: 6.839805046717326
Epoch 6 finished in 150.31048917770386 seconds, Average Loss: 6.799568255742391
Epoch 7 finished in 154.7208321094513 seconds, Average Loss: 6.771305282910665
Epoch 8 finished in 147.2647430896759 seconds, Average Loss: 6.745318313439687
Epoch 9 finished in 149.5320246219635 seconds, Average Loss: 6.728159685929616
Epoch 10 finished in 144.3449981212616 seconds, Average Loss: 6.7134332458178205
Epoch 11 finished in 143.3980541229248 seconds, Average Loss: 6.6984557112058
Epoch 12 finished in 144.52461290359497 seconds, Average Loss: 6.688430110613505
Epoch 13 finished in 147.45845818519592 seco

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
loss,█▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,6.53037


Training completed and logs saved.
Files already downloaded and verified


  scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=config["warmup_epochs"], warmup_start_lr=config["learning_rate"] * 1/10, max_epochs=config["epochs"], eta_min=0)


Epoch 1 finished in 65.19251203536987 seconds, Average Loss: 7.877552390098572
Epoch 2 finished in 65.18176412582397 seconds, Average Loss: 7.512333472569783
Epoch 3 finished in 67.58386778831482 seconds, Average Loss: 7.378252625465393
Epoch 4 finished in 80.57423520088196 seconds, Average Loss: 7.298437535762787
Epoch 5 finished in 89.68061709403992 seconds, Average Loss: 7.242946366469066
Epoch 6 finished in 78.5651581287384 seconds, Average Loss: 7.199544648329417
Epoch 7 finished in 74.99510097503662 seconds, Average Loss: 7.168028354644775
Epoch 8 finished in 72.92862701416016 seconds, Average Loss: 7.135475397109985
Epoch 9 finished in 73.32662987709045 seconds, Average Loss: 7.112058718999227
Epoch 10 finished in 73.6548490524292 seconds, Average Loss: 7.080657382806142
Epoch 11 finished in 74.32603597640991 seconds, Average Loss: 7.062133689721425
Epoch 12 finished in 73.35431385040283 seconds, Average Loss: 7.042662700017293
Epoch 13 finished in 74.29317688941956 seconds, Ave

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
loss,█▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,6.70095


Training completed and logs saved.
Files already downloaded and verified


  scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=config["warmup_epochs"], warmup_start_lr=config["learning_rate"] * 1/10, max_epochs=config["epochs"], eta_min=0)


Epoch 1 finished in 63.06705904006958 seconds, Average Loss: 7.652851541837056
Epoch 2 finished in 61.95271587371826 seconds, Average Loss: 7.124851047992706
Epoch 3 finished in 61.73205900192261 seconds, Average Loss: 6.891030251979828
Epoch 4 finished in 61.99229407310486 seconds, Average Loss: 6.707362393538157
Epoch 5 finished in 61.63967990875244 seconds, Average Loss: 6.604682425657908
Epoch 6 finished in 61.689085960388184 seconds, Average Loss: 6.549993336200714
Epoch 7 finished in 62.120986223220825 seconds, Average Loss: 6.517454703648885
Epoch 8 finished in 61.614951848983765 seconds, Average Loss: 6.495660146077474
Epoch 9 finished in 62.872801065444946 seconds, Average Loss: 6.478471120198567
Epoch 10 finished in 61.74997305870056 seconds, Average Loss: 6.464324057102203
Epoch 11 finished in 61.69476890563965 seconds, Average Loss: 6.453449110190074
Epoch 12 finished in 62.56964087486267 seconds, Average Loss: 6.44720321893692
Epoch 13 finished in 62.75314807891846 seconds



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,6.3853


Training completed and logs saved.
Files already downloaded and verified


  scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=config["warmup_epochs"], warmup_start_lr=config["learning_rate"] * 1/10, max_epochs=config["epochs"], eta_min=0)


Epoch 1 finished in 84.73026871681213 seconds, Average Loss: 7.719161589940389
Epoch 2 finished in 86.92192268371582 seconds, Average Loss: 7.134414434432983
Epoch 3 finished in 84.03852391242981 seconds, Average Loss: 6.843050042788188
Epoch 4 finished in 83.27088570594788 seconds, Average Loss: 6.725333412488301
Epoch 5 finished in 84.79025411605835 seconds, Average Loss: 6.632374405860901
Epoch 6 finished in 84.05186104774475 seconds, Average Loss: 6.561048169930776
Epoch 7 finished in 83.91335487365723 seconds, Average Loss: 6.5234248240788775
Epoch 8 finished in 84.04877209663391 seconds, Average Loss: 6.501716951529185
Epoch 9 finished in 84.18213820457458 seconds, Average Loss: 6.484394033749898
Epoch 10 finished in 84.07227921485901 seconds, Average Loss: 6.47136648495992
Epoch 11 finished in 84.67604088783264 seconds, Average Loss: 6.460646351178487
Epoch 12 finished in 85.01595902442932 seconds, Average Loss: 6.451785524686177
Epoch 13 finished in 85.30078530311584 seconds, A