In [1]:
import torch.utils.data
from torch.utils.data import DataLoader

from encoder import Encoder
from custom_dataset import ContrastiveLearningDataset
from custom_loss import contrastive_loss
from torch import optim

In [2]:
torch.manual_seed(0)

model = Encoder()
# dataset = torch.utils.data.ConcatDataset([generated_dataset, Rodrigo_dataset])
optimizer = optim.Adam(model.parameters(), lr=0.01, eps=0.001)
# optimizer = optim.RMSprop(model.parameters())
dataset = ContrastiveLearningDataset('./generated_dataset', max_size=20000,  crop_height=40)
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [3]:
torch.cuda.is_available()

True

In [4]:
from tqdm import tqdm

torch.manual_seed(0)
epochs = 8

for epoch in tqdm(range(epochs)):
    batch_loss = 0
    for batch in train_dataloader:
        original, augmented = batch['original'], batch['augmented']
        original = original.to(device)
        augmented = augmented.to(device)

        optimizer.zero_grad()

        original_embeddings, _ = model(original)
        augmented_embeddings, _ = model(augmented)
        
        flattened_original = original_embeddings.reshape(original_embeddings.shape[0] * original_embeddings.shape[1], -1)
        flattened_augmented = augmented_embeddings.reshape(original_embeddings.shape[0] * original_embeddings.shape[1], -1)

        loss = contrastive_loss(flattened_original, flattened_augmented)
        batch_loss = loss.item()
        # print(loss.item())
        loss.backward()
        optimizer.step()
    print(batch_loss)

 12%|█▎        | 1/8 [03:15<22:45, 195.13s/it]

17692.66796875


 25%|██▌       | 2/8 [06:29<19:27, 194.60s/it]

17657.69140625


 38%|███▊      | 3/8 [09:43<16:12, 194.48s/it]

17648.91796875


 50%|█████     | 4/8 [12:57<12:57, 194.32s/it]

17647.0546875


 62%|██████▎   | 5/8 [16:11<09:42, 194.21s/it]

17646.375


 75%|███████▌  | 6/8 [19:26<06:28, 194.23s/it]

17646.25


 88%|████████▊ | 7/8 [22:39<03:14, 194.04s/it]

17646.1875


100%|██████████| 8/8 [25:53<00:00, 194.15s/it]

17646.146484375





In [5]:
torch.save(model.state_dict(), 'encoder.pt')

In [6]:
!ls

 char_to_token.pkl	  make_dataset.ipynb
 custom_dataset.py	  Padilla
 custom_loss.py		 'Padilla - 1 Nobleza virtuosa_testTranscription.docx'
 data_processing.ipynb	 'Padilla - Nobleza virtuosa_testExtract.pdf'
 Decoder.py		  __pycache__
 decoder_training.ipynb   ResNet.py
 encoder.pt		  Rodrigo
 encoder.py		  test.png
 encoder_training.ipynb   token_to_char.pkl
 generated_dataset
