In [1]:
import torch.utils.data
from torch.utils.data import DataLoader

from encoder import Encoder
from custom_dataset import ContrastiveLearningDataset
from custom_loss import contrastive_loss
from torch import optim

In [2]:
# Seed PyTorch's global RNG before anything is constructed so that the
# encoder's initial weights are the same on every Restart & Run All.
torch.manual_seed(0)

# Tunable hyperparameters, collected in one place.
LEARNING_RATE = 0.01
ADAM_EPS = 0.001
BATCH_SIZE = 64

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Place the model on its final device *before* building the optimizer, so the
# optimizer is constructed over parameters that already live where training
# will happen (the ordering recommended by the torch.optim documentation).
model = Encoder().to(device)
# dataset = torch.utils.data.ConcatDataset([generated_dataset, Rodrigo_dataset])
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, eps=ADAM_EPS)
# optimizer = optim.RMSprop(model.parameters())

# Dataset built from the 'cropped_images' path; the training loop reads the
# 'original' and 'augmented' entries of each batch it yields.
dataset = ContrastiveLearningDataset('cropped_images')
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [3]:
# Sanity check: display whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

In [4]:
# Re-seed so that the shuffling order of the DataLoader is reproducible from
# this cell onwards, independent of how much randomness earlier cells consumed.
torch.manual_seed(0)
epochs = 3
num_batches = len(train_dataloader)

# Make sure the model is in training mode even if an earlier cell (or a
# previous run of the notebook) switched it to eval mode.
model.train()

for epoch in range(epochs):
    # Running sum of the per-batch losses, used for the epoch average below.
    epoch_loss = 0.0
    for i, batch in enumerate(train_dataloader):
        original = batch['original'].to(device)
        augmented = batch['augmented'].to(device)

        optimizer.zero_grad()

        # The model returns a pair; only the first element (the embeddings)
        # is used for the contrastive objective.
        original_embeddings, _ = model(original)
        augmented_embeddings, _ = model(augmented)

        # Merge the first two dimensions into one and flatten the rest, so
        # each tensor becomes 2-D. Each tensor is reshaped using its OWN
        # leading dimensions; reusing the other tensor's shape could silently
        # misalign rows if the two outputs ever differed in shape.
        # NOTE(review): presumably dims 0 and 1 are (batch, sequence) — confirm
        # against Encoder.forward.
        flattened_original = original_embeddings.reshape(
            original_embeddings.shape[0] * original_embeddings.shape[1], -1
        )
        flattened_augmented = augmented_embeddings.reshape(
            augmented_embeddings.shape[0] * augmented_embeddings.shape[1], -1
        )

        loss = contrastive_loss(flattened_original, flattened_augmented)

        # Read the scalar once: every .item() call forces a device-to-host sync.
        loss_value = loss.item()
        epoch_loss += loss_value
        if i % 100 == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Batch {i + 1}/{num_batches}, Loss:{loss_value}")

        loss.backward()
        optimizer.step()

    # Mean of the per-batch losses over this epoch.
    print(f"Epoch {epoch + 1}/{epochs}, Average Loss:{epoch_loss / num_batches}")

Epoch 1/3, Batch 1/1747, Loss:360.906494140625
Epoch 1/3, Batch 101/1747, Loss:312.74749755859375
Epoch 1/3, Batch 201/1747, Loss:311.70355224609375
Epoch 1/3, Batch 301/1747, Loss:311.18682861328125
Epoch 1/3, Batch 401/1747, Loss:311.18084716796875
Epoch 1/3, Batch 501/1747, Loss:310.6376953125
Epoch 1/3, Batch 601/1747, Loss:310.4555969238281
Epoch 1/3, Batch 701/1747, Loss:310.521728515625
Epoch 1/3, Batch 801/1747, Loss:310.435302734375
Epoch 1/3, Batch 901/1747, Loss:310.3323974609375
Epoch 1/3, Batch 1001/1747, Loss:310.37017822265625
Epoch 1/3, Batch 1101/1747, Loss:310.29913330078125
Epoch 1/3, Batch 1201/1747, Loss:310.4059753417969
Epoch 1/3, Batch 1301/1747, Loss:310.06011962890625
Epoch 1/3, Batch 1401/1747, Loss:310.087646484375
Epoch 1/3, Batch 1501/1747, Loss:310.4847717285156
Epoch 1/3, Batch 1601/1747, Loss:310.24285888671875
Epoch 1/3, Batch 1701/1747, Loss:310.0840759277344
Epoch 1/3, Average Loss:310.8630087054929
Epoch 2/3, Batch 1/1747, Loss:310.13836669921875
Ep

In [5]:
# Persist only the model's state_dict (not the module object itself) to
# 'encoder.pt' in the current working directory.
# NOTE(review): if training ran on CUDA, loading this file on a CPU-only
# machine will likely need torch.load(..., map_location=...) — confirm in the
# notebook that consumes it.
torch.save(model.state_dict(), 'encoder.pt')

In [6]:
# List the working directory (IPython shell escape); presumably used here to
# confirm that 'encoder.pt' was written by the previous cell.
!ls

 char_to_token.pkl	  make_dataset.ipynb
 custom_dataset.py	  Padilla
 custom_loss.py		 'Padilla - 1 Nobleza virtuosa_testTranscription.docx'
 data_processing.ipynb	 'Padilla - Nobleza virtuosa_testExtract.pdf'
 Decoder.py		  __pycache__
 decoder_training.ipynb   ResNet.py
 encoder.pt		  Rodrigo
 encoder.py		  test.png
 encoder_training.ipynb   token_to_char.pkl
 generated_dataset
