In [1]:
import torch.utils.data
from torch.utils.data import DataLoader

from encoder import Encoder
from custom_dataset import ContrastiveLearningDataset
from custom_loss import contrastive_loss
from torch import optim

In [2]:
# Reproducibility: seed the global torch RNG before building the model so that
# weight initialisation (and anything else drawing from it) is deterministic.
torch.manual_seed(0)

# Training configuration — tune these here rather than inline.
DATA_DIR = 'cropped_images'
BATCH_SIZE = 64
LEARNING_RATE = 0.01
ADAM_EPS = 0.001  # much larger than Adam's default eps (1e-8); kept from the original setup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to its target device *before* constructing the optimizer, as the
# torch.optim docs recommend, so the optimizer is built over the final parameters.
model = Encoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, eps=ADAM_EPS)

# Each dataset item is expected to be a dict with 'original' and 'augmented'
# views (that is how the training loop consumes batches) — TODO confirm in custom_dataset.
dataset = ContrastiveLearningDataset(DATA_DIR)
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [3]:
# Sanity check: is a CUDA device visible to this kernel? (Displayed as the cell output.)
cuda_available = torch.cuda.is_available()
cuda_available

True

In [4]:
# Contrastive training loop. Every batch carries an 'original' and an 'augmented'
# view of the same samples; both are encoded by the shared `model`, and the
# flattened embedding pair is scored by `contrastive_loss`.

# Re-seed so the DataLoader shuffle order is reproducible when this cell is re-run.
torch.manual_seed(0)
epochs = 3
LOG_EVERY = 100  # print the per-batch loss every LOG_EVERY batches

num_batches = len(train_dataloader)
if num_batches == 0:
    raise ValueError("train_dataloader yields no batches; check the dataset directory")

# nn.Module starts in train mode, but set it explicitly in case an earlier cell
# (or a re-run after evaluation) switched the model to eval mode.
model.train()

for epoch in range(epochs):
    epoch_loss = 0.0  # running sum of per-batch losses for this epoch
    for i, batch in enumerate(train_dataloader):
        original = batch['original'].to(device)
        augmented = batch['augmented'].to(device)

        optimizer.zero_grad()

        # The model returns a tuple; only the first element (the embeddings) is used.
        original_embeddings, _ = model(original)
        augmented_embeddings, _ = model(augmented)

        # Merge the first two dimensions of each embedding tensor. Each tensor is
        # reshaped using its OWN leading dims so a shape mismatch between the two
        # views surfaces as an error instead of silently mis-grouping rows.
        flattened_original = original_embeddings.reshape(
            original_embeddings.shape[0] * original_embeddings.shape[1], -1
        )
        flattened_augmented = augmented_embeddings.reshape(
            augmented_embeddings.shape[0] * augmented_embeddings.shape[1], -1
        )

        loss = contrastive_loss(flattened_original, flattened_augmented)
        loss_value = loss.item()  # single host sync per batch
        epoch_loss += loss_value
        if i % LOG_EVERY == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Batch {i + 1}/{num_batches}, Loss:{loss_value}")

        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{epochs}, Average Loss:{epoch_loss / num_batches}")

Epoch 1/3, Batch 1/1747, Loss:358.8406982421875
Epoch 1/3, Batch 101/1747, Loss:311.3273620605469
Epoch 1/3, Batch 201/1747, Loss:310.59954833984375
Epoch 1/3, Batch 301/1747, Loss:310.42193603515625
Epoch 1/3, Batch 401/1747, Loss:310.3695068359375
Epoch 1/3, Batch 501/1747, Loss:310.1219482421875
Epoch 1/3, Batch 601/1747, Loss:310.1600341796875
Epoch 1/3, Batch 701/1747, Loss:310.3951416015625
Epoch 1/3, Batch 801/1747, Loss:310.0890808105469
Epoch 1/3, Batch 901/1747, Loss:310.21746826171875
Epoch 1/3, Batch 1001/1747, Loss:310.4662780761719
Epoch 1/3, Batch 1101/1747, Loss:310.1789245605469
Epoch 1/3, Batch 1201/1747, Loss:310.03082275390625
Epoch 1/3, Batch 1301/1747, Loss:310.140625
Epoch 1/3, Batch 1401/1747, Loss:309.991943359375
Epoch 1/3, Batch 1501/1747, Loss:310.0330810546875
Epoch 1/3, Batch 1601/1747, Loss:310.13360595703125
Epoch 1/3, Batch 1701/1747, Loss:310.19140625
Epoch 1/3, Average Loss:310.4442180421739
Epoch 2/3, Batch 1/1747, Loss:310.020263671875
Epoch 2/3, Ba

In [5]:
# Persist only the encoder's learned parameters/buffers (its state_dict), not the
# module object itself; reload with `Encoder().load_state_dict(torch.load(...))`.
encoder_state = model.state_dict()
torch.save(encoder_state, 'encoder.pt')