In [79]:
import torch
import numpy as np


In [80]:
from datasets.utils.base_dataset import BaseDataset, KAND_get_loader
from datasets.utils.kand_creation import miniKAND_Dataset
from backbones.kand_encoder import  TripleMLP
import time

class MiniKandinsky(BaseDataset):
    """Mini-Kandinsky benchmark backed by the ``data/kand-3k`` splits.

    Supplies the train/val/test loaders, the concept-encoder backbone and the
    split descriptor through the ``BaseDataset`` interface.
    """

    NAME = 'minikandinsky'

    def get_data_loaders(self):
        """Load the train/val/test splits and wrap each one in a loader.

        Side effect: prints the dataset loading time and the split sizes.

        Returns:
            tuple: ``(train_loader, val_loader, test_loader)`` built with batch
            sizes 64, 500 and 500 respectively.
        """
        t0 = time.time()

        # Build the splits in train -> val -> test order. An OOD split and concept
        # masking of the train split ('red-and-squares') are currently disabled.
        split_names = ('train', 'val', 'test')
        splits = {
            name: miniKAND_Dataset(base_path='data/kand-3k', split=name)
            for name in split_names
        }

        print(f'Loaded datasets in {time.time()-t0} s.')

        print('Len loaders: \n train:', len(splits['train']), '\n val:', len(splits['val']))
        print(' len test:', len(splits['test']))

        # NOTE(review): the train loader is built with val_test=True just like
        # val/test -- confirm in KAND_get_loader whether this disables shuffling.
        batch_sizes = {'train': 64, 'val': 500, 'test': 500}
        return tuple(
            KAND_get_loader(splits[name], batch_sizes[name], val_test=True)
            for name in split_names
        )

    def get_backbone(self):
        """Return ``(encoder, 0)`` with a ``TripleMLP`` encoder of latent_dim=6.

        NOTE(review): the meaning of the second element is defined by the
        ``BaseDataset`` callers -- confirm there.
        """
        encoder = TripleMLP(latent_dim=6)  # alternative tried: TripleCNNEncoder(latent_dim=6)
        return encoder, 0

    def get_split(self):
        """Return the split descriptor ``(3, ())``.

        NOTE(review): presumably 3 is the number of figures per sample -- confirm.
        """
        return (3, ())

In [115]:
# Instantiate the dataset wrapper (without args) only to obtain its backbone encoder;
# the data loaders are built in a separate cell.
dset = MiniKandinsky(None)
encoder, _backbone_extra = dset.get_backbone()

In [116]:
# Build the kand-3k splits and loaders directly (same paths, batch sizes and
# val_test flag as MiniKandinsky.get_data_loaders), keeping the dataset objects
# around for inspection.
# Optional debugging: before building the loaders, shrink the training set by keeping
# only the first entry of dataset_train.list_images / .concepts / .labels.
dataset_train, dataset_val, dataset_test = (
    miniKAND_Dataset(base_path='data/kand-3k', split=split_name)
    for split_name in ('train', 'val', 'test')
)

train_loader, val_loader, test_loader = (
    KAND_get_loader(ds, batch_size, val_test=True)
    for ds, batch_size in ((dataset_train, 64), (dataset_val, 500), (dataset_test, 500))
)


In [117]:
# print(dataset_train.list_images )
# print(dataset_train.concepts )
# print(dataset_train.labels )

In [118]:
# Prefer the GPU, but fall back to the CPU so the notebook also runs on machines
# without CUDA (a hard-coded 'cuda' device makes every later .to(device) fail there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [119]:
# Move the encoder to the selected device (assumes encoder is a torch nn.Module,
# whose .to() returns the module itself).
encoder = encoder.to(device=device)

In [120]:
from torch.optim import Adam, SGD

# Adam over all encoder parameters: lr = 1e-3, all other hyperparameters at their defaults.
opt = Adam(params=encoder.parameters(), lr=0.001)

In [121]:
# Show the names of all encoder parameters, one per line.
for param_name, _tensor in encoder.named_parameters():
    print(param_name)

backbone.1.weight
backbone.1.bias
backbone.4.weight
backbone.4.bias
backbone.7.weight
backbone.7.bias


In [122]:
# Supervised training of the concept encoder on the ground-truth concept labels.
#
# Each image is split along its last axis into chunks of FIGURE_WIDTH pixels. Every
# chunk is encoded independently and the outputs are stacked so that preds[:, i, :]
# holds the output for chunk i. That output is cut into 3-wide logit groups, which
# are compared with the integer concept columns of concept[:, i, :].
#
# Experiment switches (the defaults reproduce the recorded run, which supervised only
# logit group 2*j+1 against concept column 3+j for j = 2):
#   LOSS_SLOTS           -- indices j whose terms enter the loss; use range(3) for all.
#   SUPERVISE_EVEN_GROUPS -- also add the (logit group 2*j, concept column j) term.
# NOTE(review): presumably groups 2*j / 2*j+1 are two concept heads (e.g. shape /
# color) of object j -- confirm against the dataset's concept layout.
N_EPOCHS = 20
LOG_EVERY = 100
FIGURE_WIDTH = 28 * 3  # width of each split along the last image axis; presumably one figure -- TODO confirm
LOSS_SLOTS = (2,)
SUPERVISE_EVEN_GROUPS = False

encoder.train()  # make sure layers such as dropout/batch-norm (if any) are in training mode

for epoch in range(N_EPOCHS):
    # `step` rather than `iter`: a module-level `iter` shadows the builtin for the whole kernel session.
    for step, (img, _label, concept) in enumerate(train_loader):
        img, concept = img.to(device), concept.to(device)

        opt.zero_grad()

        # Encode every split independently, then stack along dim 1.
        chunk_outputs = []
        for chunk in torch.split(img, FIGURE_WIDTH, dim=-1):
            out, _ = encoder(chunk)
            chunk_outputs.append(out)
        preds = torch.stack(chunk_outputs, dim=1)

        loss = 0
        for i in range(3):
            cs = torch.split(preds[:, i, :], 3, dim=-1)    # 3-wide logit groups
            gs = torch.split(concept[:, i, :], 1, dim=-1)  # one label column each
            assert len(cs) == len(gs), (cs[0].shape, gs[0].shape)

            for j in LOSS_SLOTS:
                if SUPERVISE_EVEN_GROUPS:
                    loss += torch.nn.functional.cross_entropy(cs[2*j], gs[j].view(-1)) / 3
                loss += torch.nn.functional.cross_entropy(cs[2*j+1], gs[3+j].view(-1)) / 3

        if step % LOG_EVERY == 0:
            print('Epoch',epoch, 'Iter:', step, 'Loss:', loss.item())

        loss.backward()
        opt.step()

Epoch 0 Iter: 0 Loss: 1.1249303817749023
Epoch 1 Iter: 0 Loss: 0.004875645507127047
Epoch 2 Iter: 0 Loss: 0.00699184276163578
Epoch 3 Iter: 0 Loss: 0.005085902288556099
Epoch 4 Iter: 0 Loss: 0.005486843176186085
Epoch 5 Iter: 0 Loss: 0.004140877164900303
Epoch 6 Iter: 0 Loss: 0.005515359342098236
Epoch 7 Iter: 0 Loss: 0.00797035451978445
Epoch 8 Iter: 0 Loss: 0.0092710480093956
Epoch 9 Iter: 0 Loss: 0.005397433415055275
Epoch 10 Iter: 0 Loss: 0.007073335349559784
Epoch 11 Iter: 0 Loss: 0.009108457714319229
Epoch 12 Iter: 0 Loss: 0.011702841147780418
Epoch 13 Iter: 0 Loss: 0.006490848958492279
Epoch 14 Iter: 0 Loss: 0.009613808244466782
Epoch 15 Iter: 0 Loss: 0.0061861551366746426
Epoch 16 Iter: 0 Loss: 0.006716057658195496
Epoch 17 Iter: 0 Loss: 0.006654723547399044
Epoch 18 Iter: 0 Loss: 0.008405258879065514
Epoch 19 Iter: 0 Loss: 0.007115951273590326


Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)