In [1]:
import os, json
import numpy as np
import models
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from engine import train, validate
from dataset import ClevrPOCDataSet
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel

import pickle


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root of the CLEVR-POC dataset-generation checkout. Set the CLEVR_POC_ROOT
# environment variable to run the notebook on another machine; when it is
# unset the original absolute location is used, so existing runs are unchanged.
PROJECT_ROOT = os.environ.get(
    'CLEVR_POC_ROOT',
    '/home/marjan/myworks/code/python/CLEVR-POC/clevr-poc-dataset-gen',
)

# Folder passed to ClevrPOCDataSet as the data location.
DATA_FOLDER = os.path.join(PROJECT_ROOT, 'output-2000', 'incomplete')
# Folder passed to ClevrPOCDataSet as the environment-constraints location.
ENVIRONMENT_FOLDER = os.path.join(PROJECT_ROOT, 'environment_constraints')

In [3]:

# Read the property catalogue (a JSON object keyed by property name).
with open(
    os.path.join(
        '/home/marjan/myworks/code/python/CLEVR-POC/clevr-poc-dataset-gen',
        'data',
        'properties.json',
    ),
    encoding="utf-8",
) as f:
    properties = json.load(f)

# Flatten the value names of every property except 'regions' into one list,
# keeping the order in which they appear in the file.
key_properties_values = []
for key_property, property_values in properties.items():
    if key_property != 'regions':
        key_properties_values.extend(property_values.keys())

# Map each value name to its position in the flattened list; the length of
# this mapping is later used as the classifier's output dimension.
total_labels_to_index = {
    label: position for position, label in enumerate(key_properties_values)
}
total_labels_to_index


{'cube': 0,
 'sphere': 1,
 'cylinder': 2,
 'cone': 3,
 'red': 4,
 'blue': 5,
 'green': 6,
 'yellow': 7,
 'rubber': 8,
 'metal': 9,
 'large': 10,
 'small': 11}

In [4]:

# Apply the ggplot style to every figure drawn later in the notebook.
matplotlib.style.use('ggplot')

# Select the computation device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the pretrained CLIP model and its processor from the same checkpoint id.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Disabled alternatives kept for reference:
#clip_model_path = "openai/clip-vit-base-patch32"
#clip_model, clip_preprocess = clip.load('ViT-B/32', device)


In [5]:
# Initialize the model

# Embedding widths forwarded to models.model.
# NOTE(review): 512 / 768 presumably match the CLIP projection size and the
# environment encoder's hidden size -- confirm against models.py.
clip_embedding_dim = 512
env_embedding_dim = 768

# models.model returns a (clip model, classifier head) pair; `clip_model` is
# deliberately rebound to the first element of that pair.
# NOTE(review): requires_grad=False presumably freezes the CLIP weights --
# confirm in models.py. Only the classifier's parameters are optimized below.
clip_model, final_classifier = models.model(requires_grad=False,
                                      clip_model=clip_model,
                                      clip_embedding_dim=clip_embedding_dim,
                                      env_embedding_dim=env_embedding_dim,
                                      output_dim=len(total_labels_to_index))

clip_model.to(device)
final_classifier.to(device)

# learning parameters
lr = 0.001
epochs = 75
batch_size = 50
optimizer = optim.Adam(final_classifier.parameters(), lr=lr)
# BCELoss expects probabilities in [0, 1].
# NOTE(review): assumes final_classifier ends in a sigmoid -- confirm in models.py.
criterion = nn.BCELoss()
# Single dropout module handed to engine.train / engine.validate
# (previously this assignment appeared twice in the cell).
dropout = nn.Dropout(0.1)



In [6]:
# Build the training and validation datasets (in that order) from the same
# data folder, label mapping and environment folder; only the split differs.
train_data, valid_data = (
    ClevrPOCDataSet(DATA_FOLDER, split_name, total_labels_to_index, ENVIRONMENT_FOLDER)
    for split_name in ('training', 'validation')
)

# Wrap both datasets in loaders with the shared batch size. Training batches
# are reshuffled every epoch; validation batches keep dataset order.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)

# Marker printed once the datasets and loaders have been created.
print('a')

a


In [7]:
# Run training and validation for `epochs` epochs.
#
# After every epoch the loss histories are re-written to
# outputs/train_loss.pickle and outputs/valid_loss.pickle, and whenever the
# validation loss improves the classifier/optimizer state is saved to
# outputs/best.pth.

# torch.save and open() do not create missing directories, so make sure the
# output folder exists before the first checkpoint is written.
os.makedirs('outputs', exist_ok=True)

train_loss = []
valid_loss = []
# Start from +inf so the first epoch always produces a checkpoint, whatever
# the scale of the loss.
best_validation_loss = float('inf')
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(final_classifier, clip_model, train_loader, optimizer, criterion, train_data, device, dropout, clip_preprocess)
    # NOTE(review): the same `dropout` module is passed to validate(); confirm
    # that engine.validate does not apply it in training mode.
    valid_epoch_loss = validate(final_classifier, clip_model, valid_loader, criterion, valid_data, device, dropout, clip_preprocess)
    if best_validation_loss > valid_epoch_loss:
        best_validation_loss = valid_epoch_loss
        torch.save({
            'epoch': epoch,  # zero-based index of the best epoch so far
            'model_state_dict': final_classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # NOTE(review): this stores the criterion object, not a loss
            # value; left unchanged so the checkpoint schema is preserved.
            'loss': criterion,
            }, 'outputs/best.pth')

    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    # Persist the histories every epoch so an interrupted run keeps its curves.
    with open('outputs/train_loss.pickle', 'wb') as f:
        pickle.dump(train_loss, f)
    with open('outputs/valid_loss.pickle', 'wb') as f:
        pickle.dump(valid_loss, f)

    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f'Val Loss: {valid_epoch_loss:.4f}')

Epoch 1 of 20
Training


  0%|▏                                        | 1/250 [00:41<2:54:00, 41.93s/it]

CLIPOutput(loss=None, logits_per_image=tensor([[31.8987, 29.0988, 29.1892, 28.9908, 29.0488, 24.7924, 26.8405, 26.4829],
        [27.8955, 28.6767, 27.9665, 28.0333, 28.0630, 23.4235, 26.6942, 27.9352],
        [27.4727, 28.7350, 26.6611, 26.3450, 29.0572, 23.6613, 27.9477, 26.5683],
        [28.3115, 26.4011, 28.8673, 27.0557, 25.6802, 22.7268, 22.3783, 26.9871],
        [28.7283, 28.3139, 28.8242, 27.1617, 28.6656, 24.1339, 25.3894, 27.1665],
        [28.9097, 28.2662, 27.9589, 26.4085, 27.3518, 23.8704, 23.7378, 26.3052],
        [30.2006, 27.9623, 28.1390, 28.9324, 26.4059, 23.6170, 24.6701, 26.4913],
        [27.2569, 26.9734, 26.4685, 24.9924, 26.0242, 23.3909, 28.2577, 27.5934]]), logits_per_text=tensor([[31.8987, 27.8955, 27.4727, 28.3115, 28.7283, 28.9097, 30.2006, 27.2569],
        [29.0988, 28.6767, 28.7350, 26.4011, 28.3139, 28.2662, 27.9623, 26.9734],
        [29.1892, 27.9665, 26.6611, 28.8673, 28.8242, 27.9589, 28.1390, 26.4685],
        [28.9908, 28.0333, 26.3450, 27.05

  1%|▎                                        | 2/250 [01:36<3:22:57, 49.10s/it]

CLIPOutput(loss=None, logits_per_image=tensor([[23.8810, 22.1014, 27.9076, 27.4981, 28.5248, 27.0105, 29.9292, 29.2679],
        [29.4036, 24.8819, 25.3081, 29.4310, 30.7997, 23.9100, 28.1597, 25.0897],
        [24.6462, 22.1111, 25.5356, 28.1712, 30.2157, 27.5246, 26.2206, 24.5920],
        [22.8848, 22.0811, 25.3067, 26.9892, 28.4896, 30.4294, 28.0219, 28.8122],
        [26.1887, 24.8672, 25.2822, 25.5667, 30.0706, 23.7435, 25.9236, 22.2039],
        [24.0857, 22.7296, 24.6808, 29.6440, 28.7509, 29.2257, 27.2088, 27.8170],
        [27.2460, 21.9635, 26.1323, 26.8742, 28.9070, 23.3473, 28.9856, 26.5909],
        [22.1119, 20.8080, 24.7191, 25.7474, 24.0212, 28.3817, 29.0421, 32.1142]]), logits_per_text=tensor([[23.8810, 29.4036, 24.6462, 22.8848, 26.1887, 24.0857, 27.2460, 22.1119],
        [22.1014, 24.8819, 22.1111, 22.0811, 24.8672, 22.7296, 21.9635, 20.8080],
        [27.9076, 25.3081, 25.5356, 25.3067, 25.2822, 24.6808, 26.1323, 24.7191],
        [27.4981, 29.4310, 28.1712, 26.98

  1%|▍                                        | 3/250 [02:32<3:35:43, 52.40s/it]

CLIPOutput(loss=None, logits_per_image=tensor([[28.4308, 25.6515, 21.1811, 25.1366, 26.5666, 22.8279, 29.0997, 28.1967],
        [25.5363, 26.9084, 25.3206, 26.4227, 27.1173, 24.5823, 28.1218, 30.3174],
        [24.0046, 25.1560, 26.3949, 26.5920, 25.9867, 21.1354, 28.4098, 27.7670],
        [26.2880, 25.4948, 25.1282, 25.9059, 26.3293, 21.2251, 26.5607, 25.8212],
        [26.8044, 23.4348, 22.7779, 24.9719, 26.2798, 21.5931, 29.0795, 25.8003],
        [26.2296, 25.8263, 24.0603, 25.0135, 26.4013, 23.4776, 29.4516, 31.3278],
        [27.8087, 26.5516, 24.3993, 25.7062, 26.9322, 21.7807, 28.9932, 30.4230],
        [26.7796, 28.9283, 24.7123, 28.0310, 28.5863, 23.1710, 27.0767, 29.9217]]), logits_per_text=tensor([[28.4308, 25.5363, 24.0046, 26.2880, 26.8044, 26.2296, 27.8087, 26.7796],
        [25.6515, 26.9084, 25.1560, 25.4948, 23.4348, 25.8263, 26.5516, 28.9283],
        [21.1811, 25.3206, 26.3949, 25.1282, 22.7779, 24.0603, 24.3993, 24.7123],
        [25.1366, 26.4227, 26.5920, 25.90

  1%|▍                                        | 3/250 [02:47<3:49:21, 55.72s/it]


KeyboardInterrupt: 

In [None]:
# Save the final classifier/optimizer state to disk, then plot and save the
# training and validation loss curves.

# Neither torch.save nor savefig creates a missing directory.
os.makedirs('outputs', exist_ok=True)

torch.save({
            # Number of epochs actually completed: equals `epochs` after a
            # full run and stays accurate if training was interrupted early.
            'epoch': len(train_loss),
            'model_state_dict': final_classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # NOTE(review): stores the criterion object, not a loss value;
            # kept as-is to preserve the checkpoint schema.
            'loss': criterion,
            }, 'outputs/last.pth')

# One point per completed epoch on each curve.
fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(train_loss, color='orange', label='train loss')
ax.plot(valid_loss, color='red', label='validation loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('outputs/loss.png')
plt.show()

In [None]:
# Worked example: turn a list of label names into a multi-hot target vector
# with one slot per entry of total_labels_to_index.
answer = ['red', 'yellow', 'small']

# Indices of the answer labels in the label mapping.
a = list(map(total_labels_to_index.__getitem__, answer))
# 1 at every index that belongs to the answer, 0 elsewhere.
b = [int(slot in a) for slot in range(len(total_labels_to_index))]

print(a)
one_hot_answer = torch.Tensor(b)
print(one_hot_answer)