In [1]:
import torch.utils.data
from torch.utils.data import DataLoader

from encoder import Encoder
from custom_dataset import ContrastiveLearningDataset
from custom_loss import contrastive_loss
from torch import optim

In [2]:
# --- Configuration: everything a reader might tune lives here ---------------
SEED = 0
BATCH_SIZE = 64
# NOTE(review): the logged loss rises mid-epoch 2 in the training output
# below (~1737 -> ~1785), which may mean LEARNING_RATE is too high — confirm.
LEARNING_RATE = 0.01
ADAM_EPS = 0.001
BOOK_IMAGES_DIR = 'cropped_images_line'
GENERATED_IMAGES_DIR = 'generated_dataset'

# Seed torch before the Encoder is built so its weight initialisation is
# reproducible. NOTE(review): if ContrastiveLearningDataset's augmentations
# draw from Python's `random` or NumPy, seed those RNGs as well.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to its target device *before* constructing the optimizer,
# as the torch.optim documentation recommends, so the optimizer is built
# over the parameters that will actually be trained.
model = Encoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, eps=ADAM_EPS)

# Training data: real cropped line images plus the generated set, combined
# into one dataset and shuffled each epoch.
book_dataset = ContrastiveLearningDataset(BOOK_IMAGES_DIR)
generated_dataset = ContrastiveLearningDataset(GENERATED_IMAGES_DIR)
dataset = torch.utils.data.ConcatDataset([book_dataset, generated_dataset])
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [3]:
# Sanity check: is a GPU visible to PyTorch? When this is False, `device`
# falls back to the CPU and training will be much slower.
cuda_available = torch.cuda.is_available()
cuda_available

True

In [4]:
torch.manual_seed(0)
epochs = 5

for epoch in range(epochs):
    batch_loss = 0
    for i, batch in enumerate(train_dataloader):
        original, augmented = batch['original'], batch['augmented']
        original = original.to(device)
        augmented = augmented.to(device)

        optimizer.zero_grad()

        original_embeddings, _ = model(original)
        augmented_embeddings, _ = model(augmented)
        
        flattened_original = original_embeddings.reshape(original_embeddings.shape[0] * original_embeddings.shape[1], -1)
        flattened_augmented = augmented_embeddings.reshape(original_embeddings.shape[0] * original_embeddings.shape[1], -1)

        loss = contrastive_loss(flattened_original, flattened_augmented)
        batch_loss += loss.item()
        if i % 100 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Batch {i + 1}/{len(train_dataloader)}, Loss:{loss.item()}")
        loss.backward()
        optimizer.step()
    # print(batch_loss)
    print(f"Epoch {epoch + 1}/{epochs}, Average Loss:{batch_loss / len(train_dataloader)}")

Epoch 1/5, Batch 1/621, Loss:1967.94580078125
Epoch 1/5, Batch 101/621, Loss:1790.57373046875
Epoch 1/5, Batch 201/621, Loss:1766.6455078125
Epoch 1/5, Batch 301/621, Loss:1758.156982421875
Epoch 1/5, Batch 401/621, Loss:1753.1778564453125
Epoch 1/5, Batch 501/621, Loss:1746.778076171875
Epoch 1/5, Batch 601/621, Loss:1743.1485595703125
Epoch 1/5, Average Loss:1766.971390335648
Epoch 2/5, Batch 1/621, Loss:1745.473876953125
Epoch 2/5, Batch 101/621, Loss:1740.743408203125
Epoch 2/5, Batch 201/621, Loss:1738.405517578125
Epoch 2/5, Batch 301/621, Loss:1736.8748779296875
Epoch 2/5, Batch 401/621, Loss:1749.757568359375
Epoch 2/5, Batch 501/621, Loss:1784.7783203125
Epoch 2/5, Batch 601/621, Loss:1759.12109375
Epoch 2/5, Average Loss:1747.8783701283921
Epoch 3/5, Batch 1/621, Loss:1758.75244140625
Epoch 3/5, Batch 101/621, Loss:1757.3055419921875
Epoch 3/5, Batch 201/621, Loss:1755.847900390625
Epoch 3/5, Batch 301/621, Loss:1750.13623046875
Epoch 3/5, Batch 401/621, Loss:1748.75244140625

In [5]:
# Persist only the learned weights (the state dict), not the whole module
# object; they can be restored into a fresh Encoder() via load_state_dict.
ENCODER_WEIGHTS_PATH = 'encoder.pt'
encoder_state = model.state_dict()
torch.save(encoder_state, ENCODER_WEIGHTS_PATH)