In [1]:
import os, json
import numpy as np
import models
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from engine import train, validate, test
from dataset import ClevrPOCDataSet
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel

import pickle


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset locations.
# The generator checkout defaults to the path this notebook was originally run
# from; export CLEVR_POC_ROOT before starting the kernel to point the notebook
# at a different checkout without editing any cell.
CLEVR_POC_ROOT = os.environ.get(
    'CLEVR_POC_ROOT', '/home/marjan/code/CLEVR-POC/clevr-poc-dataset-gen')

# Name of the generated dataset; also used as a suffix for every artifact
# written under outputs/ further down.
DATA_FOLDER_NAME = 'output-12000'
DATA_FOLDER = os.path.join(CLEVR_POC_ROOT, DATA_FOLDER_NAME, 'incomplete')
ENVIRONMENT_FOLDER = os.path.join(CLEVR_POC_ROOT, 'environment_constraints')

In [3]:

# Load the property vocabulary and assign every property value a fixed index.
# The checkout root honours the CLEVR_POC_ROOT environment variable and falls
# back to the path this notebook was originally run from.
properties_path = os.path.join(
    os.environ.get('CLEVR_POC_ROOT', '/home/marjan/code/CLEVR-POC/clevr-poc-dataset-gen'),
    'data', 'properties.json')
with open(properties_path, encoding="utf-8") as f:
    properties = json.load(f)

# Flatten the value names of every property group except 'regions', keeping
# the order in which they appear in the JSON file.
key_properties_values = [
    value_name
    for property_name, property_values in properties.items()
    if property_name != 'regions'
    for value_name in property_values.keys()
]

# A value name shared by two groups would silently collapse onto one index and
# make the mapping shorter than the label list, so fail loudly instead.
assert len(set(key_properties_values)) == len(key_properties_values), \
    'duplicate property value names in properties.json'

# label -> position in the flattened list; its length is later used as the
# classifier's output dimension.
total_labels_to_index = {label: index for index, label in enumerate(key_properties_values)}
total_labels_to_index


{'cube': 0,
 'sphere': 1,
 'cylinder': 2,
 'cone': 3,
 'gray': 4,
 'red': 5,
 'blue': 6,
 'green': 7,
 'brown': 8,
 'purple': 9,
 'cyan': 10,
 'yellow': 11,
 'rubber': 12,
 'metal': 13,
 'large': 14,
 'medium': 15,
 'small': 16}

In [4]:

# Use the ggplot look for every figure drawn later in the notebook.
matplotlib.style.use('ggplot')

# Pick the computation device: CUDA when torch can see a GPU, CPU otherwise.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# The CLIP weights and the matching input processor are loaded from the same
# pretrained checkpoint identifier.
clip_checkpoint_name = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(clip_checkpoint_name)
clip_preprocess = CLIPProcessor.from_pretrained(clip_checkpoint_name)


In [5]:
# Initialize the model, optimizer and loss.

# Seed torch's global RNG so that nn.Dropout masks and the shuffled DataLoader
# order are repeatable under Restart & Run All. numpy is seeded as well in case
# project code (models / dataset / engine) draws from it -- TODO confirm.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

clip_embedding_dim = 512
env_embedding_dim = 768  # presumably the width of the environment-text encoder used in `models` -- TODO confirm
clip_model, final_classifier = models.model(
    requires_grad=False,
    clip_model=clip_model,
    clip_embedding_dim=clip_embedding_dim,
    env_embedding_dim=env_embedding_dim,
    # one output unit per entry of the label index
    output_dim=len(total_labels_to_index),
)

clip_model.to(device)
final_classifier.to(device)

# learning parameters
lr = 0.001
epochs = 200
batch_size = 8
dropout_p = 0.1
# Only the classifier head's parameters are handed to the optimizer.
optimizer = optim.Adam(final_classifier.parameters(), lr=lr)
# nn.BCELoss expects probabilities in [0, 1]; assumes final_classifier ends in
# a sigmoid -- TODO confirm in models.py.
criterion = nn.BCELoss()
# Single Dropout module passed to train / validate / test below.
dropout = nn.Dropout(dropout_p)

# Probability cut-off forwarded to validate() and test().
validation_threshold = 0.5



In [6]:
# Build the training and validation datasets from the same data folder, label
# index and environment-constraint folder; only the split name differs.
train_data = ClevrPOCDataSet(DATA_FOLDER, 'training', total_labels_to_index, ENVIRONMENT_FOLDER)
valid_data = ClevrPOCDataSet(DATA_FOLDER, 'validation', total_labels_to_index, ENVIRONMENT_FOLDER)

# Batches are reshuffled every epoch for training and kept in dataset order
# for validation.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)


In [7]:
# Train for `epochs` epochs, validating after each one.
# Per-epoch metrics are appended to the three history lists and re-pickled
# every epoch, so an interrupted run still leaves usable curves on disk. The
# classifier is checkpointed whenever validation accuracy improves.

# torch.save and open(..., 'wb') do not create missing parent directories;
# without this the first checkpoint write fails after a full training epoch.
OUTPUT_DIR = 'outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

train_loss = []
valid_loss = []
valid_acc = []
best_validation_acc = 0
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(final_classifier, clip_model, train_loader, optimizer, criterion, train_data, device, dropout, clip_preprocess)
    valid_epoch_loss, valid_epoch_acc = validate(final_classifier, clip_model, valid_loader, criterion, valid_data, device, dropout, clip_preprocess, validation_threshold)
    if best_validation_acc < valid_epoch_acc:
        best_validation_acc = valid_epoch_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': final_classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Store the validation loss value for this epoch. The loss *module*
            # carries no training state, and pickling it into the checkpoint
            # prevents loading the file in weights-only mode.
            'loss': float(valid_epoch_loss),
            'valid_acc': float(valid_epoch_acc),
            }, os.path.join(OUTPUT_DIR, 'best_model_' + DATA_FOLDER_NAME + '.pth'))

    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    valid_acc.append(valid_epoch_acc)

    # Rewrite the full histories each epoch (file stem -> history list).
    for stem, history in (('train_loss', train_loss),
                          ('val_loss', valid_loss),
                          ('val_acc', valid_acc)):
        with open(os.path.join(OUTPUT_DIR, stem + '_' + DATA_FOLDER_NAME + '.pickle'), 'wb') as f:
            pickle.dump(history, f)

    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f'Val Loss: {valid_epoch_loss:.4f}')
    print(f'Val Acc: {valid_epoch_acc:.4f}')

Epoch 1 of 200
Training


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1500/1500 [05:23<00:00,  4.64it/s]


Validating


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 150/150 [00:11<00:00, 12.69it/s]


Train Loss: 0.3975
Val Loss: 0.3595
Val Acc: 0.8412
Epoch 2 of 200
Training


 14%|██████████████████▊                                                                                                                      | 206/1500 [00:16<01:41, 12.69it/s]


KeyboardInterrupt: 

In [None]:
#torch.save({
#            'epoch': epochs,
#            'model_state_dict': final_classifier.state_dict(),
#            'optimizer_state_dict': optimizer.state_dict(),
#            'loss': criterion,
#            }, 'outputs/last_model_' + DATA_FOLDER_NAME + '.pth')


# Plot the per-epoch training and validation loss on one set of axes, save the
# figure next to the other run artifacts, then display it.
os.makedirs('outputs', exist_ok=True)  # savefig does not create the directory

fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(train_loss, color='orange', label='train loss')
ax.plot(valid_loss, color='red', label='validation loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and validation loss (' + DATA_FOLDER_NAME + ')')
ax.legend()
fig.savefig('outputs/loss_' + DATA_FOLDER_NAME + '.png')
plt.show()

In [None]:
# Evaluate the best checkpoint (highest validation accuracy) on the test split.

# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine still loads on a CPU-only kernel.
checkpoint = torch.load('outputs/best_model_' + DATA_FOLDER_NAME + '.pth', map_location=device)
# load model weights state_dict
final_classifier.load_state_dict(checkpoint['model_state_dict'])

# test dataset
test_data = ClevrPOCDataSet(DATA_FOLDER, 'testing', total_labels_to_index, ENVIRONMENT_FOLDER)
# test data loader: one sample per batch, in dataset order
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

# test() returns two accuracy figures, reported below as exact and partial.
test_exact_acc, test_partial_acc = test(final_classifier, clip_model, test_loader, criterion, test_data, device, dropout, clip_preprocess, validation_threshold)
print('test_exact_acc', test_exact_acc)
print('test_partial_acc', test_partial_acc)


In [None]:
# NOTE(review): scratch cell -- prints the literal 3; no other visible cell depends on it, so it is a candidate for deletion.
print(3)

In [None]:
# Shell escape: prints the kernel's current working directory, which is what the relative 'outputs/...' paths used above resolve against.
!pwd