In [None]:
from main import model_config, tokenizer, device
from config import extra
from dataset import get_dataloaders
from models.multi_model import CLIPVisionToPhi


In [None]:
# Build the training and validation DataLoaders from the local "data" directory,
# passing along the shared tokenizer imported from `main`.
loaders = get_dataloaders("data", tokenizer)
train_dl, val_dl = loaders

In [None]:
import os

import torch
import torch.optim as optim

# Build the vision-to-Phi model and put it in training mode on the target device.
model = CLIPVisionToPhi(model_config)
model = model.to(device)
model.train()

# Sanity check: report whether any parameter of the Phi sub-model will receive updates.
phi_is_trainable = any(param.requires_grad for param in model.phi_model.parameters())
print('phi_model has trainable parameters:', phi_is_trainable)

optimizer = optim.Adam(model.parameters(), lr=0.001)

total_epochs = 15
checkpoint_dir = 'checkpoints'
# torch.save does not create missing parent directories.
os.makedirs(checkpoint_dir, exist_ok=True)

print('---->>>>> Training logs <<<<<-----')
for epoch in range(total_epochs):
    epoch_loss = []
    # A `for` loop ends cleanly when the loader is exhausted; a manual
    # `next(data_iter)` loop would raise StopIteration at the end of the epoch.
    for train_batch in train_dl:
        optimizer.zero_grad()

        output = model(
            image_feature=train_batch['image_feature'].to(device),
            caption_ids=train_batch['decoder_caption'].to(device),
            label=train_batch['label'].to(device),
            mask=train_batch['mask'].to(device),
        )

        loss = output['loss']
        loss.backward()
        epoch_loss.append(loss.item())
        optimizer.step()

    # Averaging an empty list would silently yield NaN; fail loudly instead.
    if not epoch_loss:
        raise RuntimeError('train_dl yielded no batches in epoch %d' % epoch)

    mean_loss = torch.tensor(epoch_loss, dtype=torch.float32).mean()
    print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(mean_loss.item()))

    # Full-model checkpoint.
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': mean_loss,
            }, os.path.join(checkpoint_dir, 'ckpt_%s.pth' % epoch))

    # Vision-projector-only checkpoint (still bundles the full optimizer state).
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.vision_projector.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': mean_loss,
            }, os.path.join(checkpoint_dir, 'vp_ckpt_%s.pth' % epoch))



In [None]:
import pytorch_lightning as pl

# Lightning trainer; the epoch budget comes from config.extra['num_epochs'].
# NOTE(review): accelerator is hard-coded to 'gpu' independently of the `device`
# imported from main — confirm the two agree.
num_epochs = extra['num_epochs']
trainer = pl.Trainer(accelerator='gpu', max_epochs=num_epochs)


In [None]:
# Fit with Lightning using the train and validation loaders built earlier.
# NOTE(review): `model` has already been trained by the manual loop in a previous
# cell, so this continues from those weights. trainer.fit also requires `model` to be
# a LightningModule — confirm CLIPVisionToPhi subclasses it.
# Arguments stay positional because the keyword names differ across Lightning versions.
trainer.fit(model, train_dl, val_dl)