In [1]:
import json

import torch
import torch.nn as nn

import torch.utils.data
from torch.utils.data import DataLoader, ConcatDataset

from encoder import Encoder
from custom_dataset import ContrastiveLearningDataset
from custom_loss import contrastive_loss
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Read the project configuration and keep only its "SSL" section.
# A context manager is used so the file handle is closed as soon as parsing finishes.
with open('config.json', 'r') as config_file:
    config = json.load(config_file)["SSL"]

In [2]:
# Seed torch's global RNG before the model and the data loader are created.
torch.manual_seed(0)

# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Encoder().to(device)

# Adam with the configured starting learning rate; StepLR multiplies the
# learning rate by 0.1 every `config["lr scheduler step size"]` scheduler steps.
optimizer = Adam(model.parameters(), lr=config["start lr"])
scheduler = StepLR(
    optimizer,
    step_size=config["lr scheduler step size"],
    gamma=0.1,
)

# Up to three image directories ("dataset 0" .. "dataset 2") may be configured;
# slots set to None are skipped, the rest are concatenated into one dataset.
datasets = []
for i in range(3):
    img_dir = config[f"dataset {i}"]
    if img_dir is None:
        continue
    datasets.append(ContrastiveLearningDataset(img_dir=img_dir))
dataset = ConcatDataset(datasets)

train_dataloader = DataLoader(
    dataset,
    batch_size=config["Batch size"],
    shuffle=True,
)

In [4]:
# Contrastive pre-training loop for the encoder.
# For every batch, the original and augmented views are encoded, pooled,
# flattened to 2-D and compared with `contrastive_loss`. The loss is logged
# every 100 batches and averaged per epoch.
torch.manual_seed(0)
epochs = config["epoch size"]
num_batches = len(train_dataloader)

# `step` counts every batch processed so far across all epochs, so each value
# stored in `steps` is the real number of batches seen when the matching
# entry of `loss_list` was recorded.
step = 0
steps = []
loss_list = []

for epoch in range(epochs):
    batch_loss = 0
    for i, batch in enumerate(train_dataloader):
        step += 1

        original = batch['original'].to(device)
        augmented = batch['augmented'].to(device)

        optimizer.zero_grad()

        # The encoder returns a pair; only the first element is used here.
        original_embeddings, _ = model(original)
        augmented_embeddings, _ = model(augmented)

        # Adaptive average pooling with target size (shape[1] // 4, shape[2]):
        # dim 1 is reduced to a quarter of its length, dim 2 keeps its size.
        # NOTE(review): assumes the embeddings are 3-D tensors — confirm against Encoder.
        avg_pool = nn.AdaptiveAvgPool2d((original_embeddings.shape[1] // 4, original_embeddings.shape[2]))

        original_embeddings = avg_pool(original_embeddings)
        augmented_embeddings = avg_pool(augmented_embeddings)

        # Merge the first two dimensions so each row is one embedding vector.
        # Both views are reshaped with the same row count so rows stay aligned.
        num_rows = original_embeddings.shape[0] * original_embeddings.shape[1]
        flattened_original = original_embeddings.reshape(num_rows, -1)
        flattened_augmented = augmented_embeddings.reshape(num_rows, -1)

        loss = contrastive_loss(flattened_original, flattened_augmented)

        # Read the scalar once and reuse it for the running sum and the log.
        loss_value = loss.item()
        batch_loss += loss_value
        if i % 100 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Batch {i + 1}/{num_batches}, Loss:{loss_value}")
            steps.append(step)
            loss_list.append(loss_value)

        loss.backward()
        optimizer.step()
        # NOTE(review): the scheduler advances once per batch, so
        # config["lr scheduler step size"] is measured in batches, not epochs — confirm this is intended.
        scheduler.step()
    print(f"Epoch {epoch + 1}/{epochs}, Average Loss:{batch_loss / num_batches}")

Epoch 1/5, Batch 1/2481, Loss:407.0021667480469
Epoch 1/5, Batch 101/2481, Loss:350.645751953125
Epoch 1/5, Batch 201/2481, Loss:349.701171875
Epoch 1/5, Batch 301/2481, Loss:349.75555419921875
Epoch 1/5, Batch 401/2481, Loss:348.7194519042969
Epoch 1/5, Batch 501/2481, Loss:348.22564697265625
Epoch 1/5, Batch 601/2481, Loss:347.7273254394531
Epoch 1/5, Batch 701/2481, Loss:347.35858154296875
Epoch 1/5, Batch 801/2481, Loss:347.50958251953125
Epoch 1/5, Batch 901/2481, Loss:347.58740234375
Epoch 1/5, Batch 1001/2481, Loss:347.45751953125
Epoch 1/5, Batch 1101/2481, Loss:348.71856689453125
Epoch 1/5, Batch 1201/2481, Loss:347.76416015625
Epoch 1/5, Batch 1301/2481, Loss:347.580322265625
Epoch 1/5, Batch 1401/2481, Loss:347.369140625
Epoch 1/5, Batch 1501/2481, Loss:347.216552734375
Epoch 1/5, Batch 1601/2481, Loss:347.2215576171875
Epoch 1/5, Batch 1701/2481, Loss:347.15191650390625
Epoch 1/5, Batch 1801/2481, Loss:347.2649230957031
Epoch 1/5, Batch 1901/2481, Loss:347.0948181152344
Epo

In [None]:
from matplotlib import pyplot as plt

# Plot the logged training loss against the recorded step values.
# Axis labels are set on the Axes: `plt.plot` forwards extra keyword arguments
# to the line artist, which does not accept `xlabel` / `ylabel`.
fig, ax = plt.subplots()
ax.plot(steps, loss_list)
ax.set(xlabel="Steps", ylabel="Loss")
plt.show()

In [5]:
# Persist the encoder's state dict to the path given in the config.
torch.save(
    obj=model.state_dict(),
    f=config["saved Encoder path"],
)